From 8c6b8249aed97f18a6de471cb5544689befd58e4 Mon Sep 17 00:00:00 2001 From: Dan Padilha Date: Thu, 3 Dec 2020 18:01:00 +0900 Subject: [PATCH 01/26] Add example of failing precompilation test. --- test/precompile_test.jl | 21 +++++++++++++++++++++ test/precompile_test/ODEPrecompileTest.jl | 22 ++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 44 insertions(+) create mode 100644 test/precompile_test.jl create mode 100644 test/precompile_test/ODEPrecompileTest.jl diff --git a/test/precompile_test.jl b/test/precompile_test.jl new file mode 100644 index 0000000000..1c4d504a71 --- /dev/null +++ b/test/precompile_test.jl @@ -0,0 +1,21 @@ +using Test +using ModelingToolkit + +# Test that the precompiled ODE system works +push!(LOAD_PATH, joinpath(@__DIR__, "precompile_test")) +using ODEPrecompileTest + +du = zeros(3) +u = collect(1:3) +p = collect(4:6) + +# This case does not work, because the function gets defined in ModelingToolkit +# instead of in the compiled module! +@test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_iip).parameters[2]) == ModelingToolkit +@test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_oop).parameters[2]) == ModelingToolkit +@test_throws KeyError ODEPrecompileTest.f_bad(du, u, p, 0.1) + +# This case works, because the function gets defined in the compiled module. +# @test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_iip).parameters[2]) == ODEPrecompileTest +# @test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_oop).parameters[2]) == ODEPrecompileTest +# @test ODEPrecompileTest.f_good(du, u, p, 0.1) \ No newline at end of file diff --git a/test/precompile_test/ODEPrecompileTest.jl b/test/precompile_test/ODEPrecompileTest.jl new file mode 100644 index 0000000000..ef5146d228 --- /dev/null +++ b/test/precompile_test/ODEPrecompileTest.jl @@ -0,0 +1,22 @@ +module ODEPrecompileTest + using ModelingToolkit + + function system(; kwargs...) + # Define some variables + @parameters t σ ρ β + @variables x(t) y(t) z(t) + @derivatives D'~t + + # Define a differential equation + eqs = [D(x) ~ σ*(y-x), + D(y) ~ x*(ρ-z)-y, + D(z) ~ x*y - β*z] + + de = ODESystem(eqs) + return ODEFunction(de, [x,y,z], [σ,ρ,β]; kwargs...) + end + + # Build a simple ODEFunction as part of the module's precompilation. + const f_bad = system() + # const f_good = system(; eval_module=@__MODULE__) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 5ad02512c9..0049c05ec1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,3 +33,4 @@ using SafeTestsets, Test println("Last test requires gcc available in the path!") @safetestset "C Compilation Test" begin include("ccompile.jl") end @safetestset "Latexify recipes Test" begin include("latexify.jl") end +@safetestset "Precompiled Modules Test" begin include("precompile_test.jl") end From 1b14febb4c642268cf9a2d448af24ee2503fd88b Mon Sep 17 00:00:00 2001 From: Dan Padilha Date: Thu, 3 Dec 2020 18:16:11 +0900 Subject: [PATCH 02/26] Add "eval_module" option to specify which module to use for RGF caching. --- src/build_function.jl | 5 +++-- src/systems/diffeqs/abstractodesystem.jl | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 752f7e69ac..2d8787e23e 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -100,6 +100,7 @@ end # Scalar output function _build_function(target::JuliaTarget, op, args...; conv = toexpr, expression = Val{true}, + expression_module = @__MODULE__, checkbounds = false, linenumbers = true, headerfun=addheader) @@ -127,12 +128,12 @@ function _build_function(target::JuliaTarget, op, args...; if expression == Val{true} return ModelingToolkit.inject_registered_module_functions(oop_ex) else - _build_and_inject_function(@__MODULE__, oop_ex) + _build_and_inject_function(expression_module, oop_ex) end end function _build_and_inject_function(mod::Module, ex) - @RuntimeGeneratedFunction(ModelingToolkit.inject_registered_module_functions(ex)) + @RuntimeGeneratedFunction(mod, ModelingToolkit.inject_registered_module_functions(ex)) end # Detect heterogeneous element types of "arrays of matrices/sparce matrices" diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 388ad9c3de..216a5b7c3b 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -124,11 +124,12 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), version = nothing, tgrad=false, jac = false, eval_expression = true, + eval_module = @__MODULE__, sparse = false, simplify = true, kwargs...) where {iip} f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...) - f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in f_gen) : f_gen + f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen f(u,p,t) = f_oop(u,p,t) f(du,u,p,t) = f_iip(du,u,p,t) @@ -136,7 +137,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), tgrad_gen = generate_tgrad(sys, dvs, ps; simplify=simplify, expression=Val{eval_expression}, kwargs...) - tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in tgrad_gen) : tgrad_gen + tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in tgrad_gen) : tgrad_gen _tgrad(u,p,t) = tgrad_oop(u,p,t) _tgrad(J,u,p,t) = tgrad_iip(J,u,p,t) else @@ -147,7 +148,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), jac_gen = generate_jacobian(sys, dvs, ps; simplify=simplify, sparse = sparse, expression=Val{eval_expression}, kwargs...) - jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in jac_gen) : jac_gen + jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in jac_gen) : jac_gen _jac(u,p,t) = jac_oop(u,p,t) _jac(J,u,p,t) = jac_iip(J,u,p,t) else From 87ec50a135e811e952579610f9b3e37ea7a56038 Mon Sep 17 00:00:00 2001 From: Dan Padilha Date: Thu, 3 Dec 2020 18:16:41 +0900 Subject: [PATCH 03/26] Add tests for "f_good" showing that "eval_module" works. --- test/precompile_test.jl | 13 ++++++------- test/precompile_test/ODEPrecompileTest.jl | 10 ++++++++-- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/test/precompile_test.jl b/test/precompile_test.jl index 1c4d504a71..d91a85133a 100644 --- a/test/precompile_test.jl +++ b/test/precompile_test.jl @@ -5,17 +5,16 @@ using ModelingToolkit push!(LOAD_PATH, joinpath(@__DIR__, "precompile_test")) using ODEPrecompileTest -du = zeros(3) u = collect(1:3) p = collect(4:6) -# This case does not work, because the function gets defined in ModelingToolkit +# This case does not work, because "f_bad" gets defined in ModelingToolkit # instead of in the compiled module! @test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_iip).parameters[2]) == ModelingToolkit @test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_oop).parameters[2]) == ModelingToolkit -@test_throws KeyError ODEPrecompileTest.f_bad(du, u, p, 0.1) +@test_throws KeyError ODEPrecompileTest.f_bad(u, p, 0.1) -# This case works, because the function gets defined in the compiled module. -# @test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_iip).parameters[2]) == ODEPrecompileTest -# @test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_oop).parameters[2]) == ODEPrecompileTest -# @test ODEPrecompileTest.f_good(du, u, p, 0.1) \ No newline at end of file +# This case works, because "f_good" gets defined in the precompiled module. +@test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_iip).parameters[2]) == ODEPrecompileTest +@test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_oop).parameters[2]) == ODEPrecompileTest +@test ODEPrecompileTest.f_good(u, p, 0.1) == [4, 0, -16] \ No newline at end of file diff --git a/test/precompile_test/ODEPrecompileTest.jl b/test/precompile_test/ODEPrecompileTest.jl index ef5146d228..ed9ccda7fc 100644 --- a/test/precompile_test/ODEPrecompileTest.jl +++ b/test/precompile_test/ODEPrecompileTest.jl @@ -16,7 +16,13 @@ module ODEPrecompileTest return ODEFunction(de, [x,y,z], [σ,ρ,β]; kwargs...) end - # Build a simple ODEFunction as part of the module's precompilation. + # Build an ODEFunction as part of the module's precompilation. This case + # will not work, because the generated RGFs will be put into + # ModelingToolkit's RGF cache. const f_bad = system() - # const f_good = system(; eval_module=@__MODULE__) + + # This case will work, because it will be put into our own module's cache. + using RuntimeGeneratedFunctions + RuntimeGeneratedFunctions.init(@__MODULE__) + const f_good = system(; eval_module=@__MODULE__) end \ No newline at end of file From c9a183ebba7cf16685350e3ad100a7bd18090e2a Mon Sep 17 00:00:00 2001 From: Dan Padilha Date: Thu, 3 Dec 2020 18:37:54 +0900 Subject: [PATCH 04/26] Add expression_module as well. --- src/systems/diffeqs/abstractodesystem.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 216a5b7c3b..a38b2a343d 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -128,7 +128,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), sparse = false, simplify = true, kwargs...) where {iip} - f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...) + f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, expression_module=eval_module, kwargs...) f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen f(u,p,t) = f_oop(u,p,t) f(du,u,p,t) = f_iip(du,u,p,t) @@ -136,7 +136,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), if tgrad tgrad_gen = generate_tgrad(sys, dvs, ps; simplify=simplify, - expression=Val{eval_expression}, kwargs...) + expression=Val{eval_expression}, expression_module=eval_module, kwargs...) tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in tgrad_gen) : tgrad_gen _tgrad(u,p,t) = tgrad_oop(u,p,t) _tgrad(J,u,p,t) = tgrad_iip(J,u,p,t) @@ -147,7 +147,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), if jac jac_gen = generate_jacobian(sys, dvs, ps; simplify=simplify, sparse = sparse, - expression=Val{eval_expression}, kwargs...) + expression=Val{eval_expression}, expression_module=eval_module, kwargs...) jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in jac_gen) : jac_gen _jac(u,p,t) = jac_oop(u,p,t) _jac(J,u,p,t) = jac_iip(J,u,p,t) From 5cda7f92f282075a7070213d3cc42a2c554f734b Mon Sep 17 00:00:00 2001 From: Dan Padilha Date: Thu, 3 Dec 2020 18:46:19 +0900 Subject: [PATCH 05/26] Also fix eval_expression=false case. --- src/build_function.jl | 3 ++- test/precompile_test.jl | 8 +++++++- test/precompile_test/ODEPrecompileTest.jl | 4 ++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 2d8787e23e..dc6ff3642b 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -219,6 +219,7 @@ Special Keyword Argumnets: """ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; conv = toexpr, expression = Val{true}, + expression_module = @__MODULE__, checkbounds = false, linenumbers = false, multithread=nothing, headerfun = addheader, outputidxs=nothing, @@ -458,7 +459,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; if expression == Val{true} return ModelingToolkit.inject_registered_module_functions(oop_ex), ModelingToolkit.inject_registered_module_functions(iip_ex) else - return _build_and_inject_function(@__MODULE__, oop_ex), _build_and_inject_function(@__MODULE__, iip_ex) + return _build_and_inject_function(expression_module, oop_ex), _build_and_inject_function(expression_module, iip_ex) end end diff --git a/test/precompile_test.jl b/test/precompile_test.jl index d91a85133a..61abebe55e 100644 --- a/test/precompile_test.jl +++ b/test/precompile_test.jl @@ -12,9 +12,15 @@ p = collect(4:6) # instead of in the compiled module! @test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_iip).parameters[2]) == ModelingToolkit @test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_oop).parameters[2]) == ModelingToolkit +@test parentmodule(typeof(ODEPrecompileTest.f_noeval_bad.f.f_iip).parameters[2]) == ModelingToolkit +@test parentmodule(typeof(ODEPrecompileTest.f_noeval_bad.f.f_oop).parameters[2]) == ModelingToolkit @test_throws KeyError ODEPrecompileTest.f_bad(u, p, 0.1) +@test_throws KeyError ODEPrecompileTest.f_noeval_bad(u, p, 0.1) # This case works, because "f_good" gets defined in the precompiled module. @test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_iip).parameters[2]) == ODEPrecompileTest @test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_oop).parameters[2]) == ODEPrecompileTest -@test ODEPrecompileTest.f_good(u, p, 0.1) == [4, 0, -16] \ No newline at end of file +@test parentmodule(typeof(ODEPrecompileTest.f_noeval_good.f.f_iip).parameters[2]) == ODEPrecompileTest +@test parentmodule(typeof(ODEPrecompileTest.f_noeval_good.f.f_oop).parameters[2]) == ODEPrecompileTest +@test ODEPrecompileTest.f_good(u, p, 0.1) == [4, 0, -16] +@test ODEPrecompileTest.f_noeval_good(u, p, 0.1) == [4, 0, -16] \ No newline at end of file diff --git a/test/precompile_test/ODEPrecompileTest.jl b/test/precompile_test/ODEPrecompileTest.jl index ed9ccda7fc..453cb0d774 100644 --- a/test/precompile_test/ODEPrecompileTest.jl +++ b/test/precompile_test/ODEPrecompileTest.jl @@ -25,4 +25,8 @@ module ODEPrecompileTest using RuntimeGeneratedFunctions RuntimeGeneratedFunctions.init(@__MODULE__) const f_good = system(; eval_module=@__MODULE__) + + # Also test that eval_expression=false works + const f_noeval_bad = system(; eval_expression=false) + const f_noeval_good = system(; eval_expression=false, eval_module=@__MODULE__) end \ No newline at end of file From fc64aa2d8b250fd91babf6467961b9d702a2d828 Mon Sep 17 00:00:00 2001 From: Dan Padilha Date: Fri, 18 Dec 2020 14:58:48 +0900 Subject: [PATCH 06/26] Updated syntax to use new constructor changes from RuntimeGeneratedFunctions.jl#20 --- src/build_function.jl | 2 +- src/systems/diffeqs/abstractodesystem.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index dc6ff3642b..049b2f0218 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -133,7 +133,7 @@ function _build_function(target::JuliaTarget, op, args...; end function _build_and_inject_function(mod::Module, ex) - @RuntimeGeneratedFunction(mod, ModelingToolkit.inject_registered_module_functions(ex)) + RuntimeGeneratedFunction(mod, ModelingToolkit.inject_registered_module_functions(ex)) end # Detect heterogeneous element types of "arrays of matrices/sparce matrices" diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index a38b2a343d..3c4789f094 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -129,7 +129,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), kwargs...) where {iip} f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, expression_module=eval_module, kwargs...) - f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen + f_oop,f_iip = eval_expression ? (RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen f(u,p,t) = f_oop(u,p,t) f(du,u,p,t) = f_iip(du,u,p,t) @@ -137,7 +137,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), tgrad_gen = generate_tgrad(sys, dvs, ps; simplify=simplify, expression=Val{eval_expression}, expression_module=eval_module, kwargs...) - tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in tgrad_gen) : tgrad_gen + tgrad_oop,tgrad_iip = eval_expression ? (RuntimeGeneratedFunction(eval_module, ex) for ex in tgrad_gen) : tgrad_gen _tgrad(u,p,t) = tgrad_oop(u,p,t) _tgrad(J,u,p,t) = tgrad_iip(J,u,p,t) else @@ -148,7 +148,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), jac_gen = generate_jacobian(sys, dvs, ps; simplify=simplify, sparse = sparse, expression=Val{eval_expression}, expression_module=eval_module, kwargs...) - jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in jac_gen) : jac_gen + jac_oop,jac_iip = eval_expression ? (RuntimeGeneratedFunction(eval_module, ex) for ex in jac_gen) : jac_gen _jac(u,p,t) = jac_oop(u,p,t) _jac(J,u,p,t) = jac_iip(J,u,p,t) else From 6705d6adde0fe104fce11e21beda26810f0231f6 Mon Sep 17 00:00:00 2001 From: Dan Padilha Date: Fri, 18 Dec 2020 15:10:07 +0900 Subject: [PATCH 07/26] Update version requirement for RuntimeGeneratedFunctions. --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ff8d8ce886..9dd107c933 100644 --- a/Project.toml +++ b/Project.toml @@ -46,7 +46,7 @@ MacroTools = "0.5" NaNMath = "0.3" RecursiveArrayTools = "2.3" Requires = "1.0" -RuntimeGeneratedFunctions = "0.4" +RuntimeGeneratedFunctions = "0.4.3" SafeTestsets = "0.0.1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0" StaticArrays = "0.10, 0.11, 0.12, 1.0" From 73390fb98d8b74418046d1a3679b46fc32fbecd6 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 4 Jan 2021 08:10:17 -0500 Subject: [PATCH 08/26] Update abstractodesystem.jl --- src/systems/diffeqs/abstractodesystem.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 3c4789f094..a38b2a343d 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -129,7 +129,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), kwargs...) where {iip} f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, expression_module=eval_module, kwargs...) - f_oop,f_iip = eval_expression ? (RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen + f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen f(u,p,t) = f_oop(u,p,t) f(du,u,p,t) = f_iip(du,u,p,t) @@ -137,7 +137,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), tgrad_gen = generate_tgrad(sys, dvs, ps; simplify=simplify, expression=Val{eval_expression}, expression_module=eval_module, kwargs...) - tgrad_oop,tgrad_iip = eval_expression ? (RuntimeGeneratedFunction(eval_module, ex) for ex in tgrad_gen) : tgrad_gen + tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in tgrad_gen) : tgrad_gen _tgrad(u,p,t) = tgrad_oop(u,p,t) _tgrad(J,u,p,t) = tgrad_iip(J,u,p,t) else @@ -148,7 +148,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), jac_gen = generate_jacobian(sys, dvs, ps; simplify=simplify, sparse = sparse, expression=Val{eval_expression}, expression_module=eval_module, kwargs...) - jac_oop,jac_iip = eval_expression ? (RuntimeGeneratedFunction(eval_module, ex) for ex in jac_gen) : jac_gen + jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in jac_gen) : jac_gen _jac(u,p,t) = jac_oop(u,p,t) _jac(J,u,p,t) = jac_iip(J,u,p,t) else From 8170e4d5eabbc91d56f1cab6647a19df2d51599b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 4 Jan 2021 08:12:08 -0500 Subject: [PATCH 09/26] Update build_function.jl --- src/build_function.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/build_function.jl b/src/build_function.jl index 049b2f0218..dc6ff3642b 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -133,7 +133,7 @@ function _build_function(target::JuliaTarget, op, args...; end function _build_and_inject_function(mod::Module, ex) - RuntimeGeneratedFunction(mod, ModelingToolkit.inject_registered_module_functions(ex)) + @RuntimeGeneratedFunction(mod, ModelingToolkit.inject_registered_module_functions(ex)) end # Detect heterogeneous element types of "arrays of matrices/sparce matrices" From ff6358d835bfe9b2fbbb7bac695e5bb3e9fe1c8d Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sun, 31 Jan 2021 19:20:21 -0500 Subject: [PATCH 10/26] redo scalar build_function --- Project.toml | 2 +- src/ModelingToolkit.jl | 2 ++ src/build_function.jl | 42 ++++++++++-------------------------------- src/direct.jl | 28 ---------------------------- 4 files changed, 13 insertions(+), 61 deletions(-) diff --git a/Project.toml b/Project.toml index e65cadd8cb..f0abd4951b 100644 --- a/Project.toml +++ b/Project.toml @@ -54,7 +54,7 @@ SafeTestsets = "0.0.1" Setfield = "0.7" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0" StaticArrays = "0.10, 0.11, 0.12, 1.0" -SymbolicUtils = "0.7.4" +SymbolicUtils = "0.8" TreeViews = "0.3" UnPack = "0.1, 1.0" Unitful = "1.1" diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 8b9e76f458..cde5ed721f 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -27,6 +27,8 @@ import SymbolicUtils: Term, Add, Mul, Pow, Sym, to_symbolic, FnType, @rule, Rewriters, substitute, similarterm, promote_symtype +import SymbolicUtils.Code: toexpr + import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint using LinearAlgebra: LU, BlasInt diff --git a/src/build_function.jl b/src/build_function.jl index 65fc04dcdf..83da819229 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -121,48 +121,26 @@ function _build_function(target::JuliaTarget, op::Let, args...; conv=toexpr, kw. end # Scalar output + function _build_function(target::JuliaTarget, op, args...; - conv = toexpr, expression = Val{true}, + conv = toexpr, + expression = Val{true}, checkbounds = false, - inner_let = nothing, - linenumbers = true, headerfun=addheader) - - argnames = [gensym(:MTKArg) for i in 1:length(args)] - symsdict = Dict() - arg_pairs = map((x,y)->vars_to_pairs(x,y, symsdict), argnames, args) - process = unflatten_long_ops∘(x->substitute(x, symsdict, fold=false)) - ls = reduce(vcat,conv.(first.(arg_pairs))) - rs = reduce(vcat,last.(arg_pairs)) - var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, conv.(process.(rs)))) - - fname = gensym(:ModelingToolkitFunction) - op = process(op) - out_expr = conv(substitute(op, symsdict, fold=false)) + linenumbers = true, + headerfun=addheader) - if inner_let !== nothing - inner_let_expr = inner_let(conv ∘ process) - out_expr = inner_let_expr(out_expr) - end - - let_expr = Expr(:let, var_eqs, Expr(:block, out_expr)) - bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end) - - fargs = Expr(:tuple,argnames...) - oop_ex = headerfun(bounds_block, fargs, false) - - if !linenumbers - oop_ex = striplines(oop_ex) - end + dargs = map(destructure_arg, args) + expr = toexpr(Func(dargs, [], value(op))) if expression == Val{true} - return ModelingToolkit.inject_registered_module_functions(oop_ex) + expr else - _build_and_inject_function(@__MODULE__, oop_ex) + _build_and_inject_function(@__MODULE__, expr) end end function _build_and_inject_function(mod::Module, ex) - @RuntimeGeneratedFunction(ModelingToolkit.inject_registered_module_functions(ex)) + @RuntimeGeneratedFunction(ex) end # Detect heterogeneous element types of "arrays of matrices/sparce matrices" diff --git a/src/direct.jl b/src/direct.jl index f4b35385b7..5b19e11998 100644 --- a/src/direct.jl +++ b/src/direct.jl @@ -231,34 +231,6 @@ function sparsehessian(O, vars::AbstractVector; simplify=false) return H end -""" - toexpr(O::Union{Symbolics,Num,Equation,AbstractArray}; canonicalize=true) -> Expr - -Convert `Symbolics` into `Expr`. If `canonicalize`, then we turn exprs like -`x^(-n)` into `inv(x)^n` to avoid type error when evaluating. -""" -function toexpr(O; canonicalize=true) - if canonicalize - canonical, O = canonicalexpr(O) - canonical && return O - else - !istree(O) && return O - end - - op = operation(O) - args = arguments(O) - if op isa Differential - ex = toexpr(args[1]; canonicalize=canonicalize) - wrt = toexpr(op.x; canonicalize=canonicalize) - return :(_derivative($ex, $wrt)) - elseif op isa Sym - isempty(args) && return nameof(op) - return Expr(:call, toexpr(op; canonicalize=canonicalize), toexpr(args; canonicalize=canonicalize)...) - end - return Expr(:call, op, toexpr(args; canonicalize=canonicalize)...) -end -toexpr(s::Sym; kw...) = nameof(s) - """ canonicalexpr(O) -> (canonical::Bool, expr) From 1dccc18ebf3b0ee0cac3bf84cd41d255b0dac6fc Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Mon, 1 Feb 2021 19:30:12 -0500 Subject: [PATCH 11/26] get build_function tests to pass --- src/build_function.jl | 279 ++++++------------------------------------ 1 file changed, 36 insertions(+), 243 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 83da819229..7839f487ea 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -97,11 +97,6 @@ function unflatten_long_ops(op, N=4) Rewriters.Fixpoint(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2])))(op) end -struct Let - eqs::Vector - body -end - function observed_let(eqs) process -> ex -> begin isempty(eqs) && return ex @@ -115,11 +110,6 @@ function observed_let(eqs) end end -function _build_function(target::JuliaTarget, op::Let, args...; conv=toexpr, kw...) - _build_function(target, op.body, args...; - inner_let = observed_let(op.eqs), kw...) -end - # Scalar output function _build_function(target::JuliaTarget, op, args...; @@ -158,6 +148,8 @@ function is_array_array_sparse_matrix(F) return isa(F, AbstractVector) && all(x->isa(x, AbstractArray{<:AbstractSparseMatrix}), F) end +toexpr(n::Num, st) = toexpr(value(n), st) + function fill_array_with_zero!(x::AbstractArray) if eltype(x) <: AbstractArray foreach(fill_array_with_zero!, x) @@ -226,255 +218,56 @@ Special Keyword Argumnets: """ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; conv = toexpr, expression = Val{true}, - inner_let = nothing, checkbounds = false, linenumbers = false, multithread=nothing, - headerfun = addheader, outputidxs=nothing, - convert_oop = true, force_SA = false, - skipzeros = outputidxs===nothing, + outputidxs=nothing, + skipzeros = false, fillzeros = skipzeros && !(typeof(rhss)<:SparseMatrixCSC), parallel=SerialForm(), kwargs...) - if multithread isa Bool - @warn("multithraded is deprecated for the parallel argument. See the documentation.") - parallel = multithread ? MultithreadedForm() : SerialForm() - end - - argnames = [gensym(:MTKArg) for i in 1:length(args)] - symsdict = Dict() - arg_pairs = map((x,y)->vars_to_pairs(x,y, symsdict), argnames, args) - process = unflatten_long_ops∘(x->substitute(x, symsdict, fold=false)) - - ls = reduce(vcat,conv.(first.(arg_pairs))) - rs = reduce(vcat,last.(arg_pairs)) - var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, conv.(process.(rs)))) - - fname = gensym(:ModelingToolkitFunction) - fargs = Expr(:tuple,argnames...) - - - oidx = isnothing(outputidxs) ? (i -> i) : (i -> outputidxs[i]) - X = gensym(:MTIIPVar) - - if rhss isa SparseMatrixCSC - rhs_length = length(rhss.nzval) - rhss = SparseMatrixCSC(rhss.m, rhss.m, rhss.colptr, rhss.rowval, map(process, rhss.nzval)) - else - rhs_length = length(rhss) - rhss = [process(r) for r in rhss] - end - - if parallel isa DistributedForm - numworks = Distributed.nworkers() - reducevars = [gensym(:MTReduceVar) for i in 1:numworks] - lens = Int(ceil(rhs_length/numworks)) - finalsize = rhs_length - (numworks-1)*lens - _rhss = vcat(reduce(vcat,[[Variable(reducevars[i],j) for j in 1:lens] for i in 1:numworks-1],init=Expr[]), - [Variable(reducevars[end],j) for j in 1:finalsize]) - - elseif parallel isa DaggerForm - computevars = [gensym(:MTComputeVar) for i in axes(rhss,1)] - reducevar = Variable(gensym(:MTReduceVar)) - _rhss = [Variable(reducevar,i) for i in axes(rhss,1)] - elseif rhss isa SparseMatrixCSC - _rhss = rhss.nzval - else - _rhss = rhss - end - ip_sys_exprs = Expr[] - # we cannot reliably fill the array with the presence of index translation - if is_array_array_sparse_matrix(rhss) # Array of arrays of sparse matrices - for (i, rhsel) ∈ enumerate(_rhss) - for (j, rhsel2) ∈ enumerate(rhsel) - for (k, rhs) ∈ enumerate(rhsel2.nzval) - rhs′ = conv(rhs) - (skipzeros && rhs′ isa Number && iszero(rhs′)) && continue - push!(ip_sys_exprs, :($X[$i][$j].nzval[$k] = $rhs′)) - end - end - end - elseif is_array_array_matrix(rhss) # Array of arrays of arrays - for (i, rhsel) ∈ enumerate(_rhss) - for (j, rhsel2) ∈ enumerate(rhsel) - for (k, rhs) ∈ enumerate(rhsel2) - rhs′ = conv(rhs) - (skipzeros && rhs′ isa Number && iszero(rhs′)) && continue - push!(ip_sys_exprs, :($X[$i][$j][$k] = $rhs′)) - end - end - end - elseif is_array_sparse_matrix(rhss) # Array of sparse matrices - for (i, rhsel) ∈ enumerate(_rhss) - for (j, rhs) ∈ enumerate(rhsel.nzval) - rhs′ = conv(rhs) - (skipzeros && rhs′ isa Number && iszero(rhs′)) && continue - push!(ip_sys_exprs, :($X[$i].nzval[$j] = $rhs′)) - end - end - elseif is_array_matrix(rhss) # Array of arrays - for (i, rhsel) ∈ enumerate(_rhss) - for (j, rhs) ∈ enumerate(rhsel) - rhs′ = conv(rhs) - (skipzeros && rhs′ isa Number && iszero(rhs′)) && continue - push!(ip_sys_exprs, :($X[$i][$j] = $rhs′)) - end - end - elseif rhss isa SparseMatrixCSC - for (i, rhs) ∈ enumerate(_rhss) - rhs′ = conv(rhs) - (skipzeros && rhs′ isa Number && iszero(rhs′)) && continue - push!(ip_sys_exprs, :($X.nzval[$i] = $rhs′)) - end - else - for (i, rhs) ∈ enumerate(_rhss) - rhs′ = conv(rhs) - (skipzeros && rhs′ isa Number && iszero(rhs′)) && continue - push!(ip_sys_exprs, :($X[$(oidx(i))] = $rhs′)) - end - end - - ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs)) - - if parallel isa MultithreadedForm - lens = Int(ceil(length(ip_let_expr.args[2].args)/Threads.nthreads())) - threaded_exprs = vcat([quote - Threads.@spawn begin - $(ip_let_expr.args[2].args[((i-1)*lens+1):i*lens]...) - end - end for i in 1:Threads.nthreads()-1], - quote - Threads.@spawn begin - $(ip_let_expr.args[2].args[((Threads.nthreads()-1)*lens+1):end]...) - end - end) - ip_let_expr.args[2] = ModelingToolkit.build_expr(:block, threaded_exprs) - ip_let_expr = :(@sync begin $ip_let_expr end) - elseif parallel isa DistributedForm - numworks = Distributed.nworkers() - lens = Int(ceil(length(ip_let_expr.args[2].args)/numworks)) - spawnvars = [gensym(:MTSpawnVar) for i in 1:numworks] - rhss_flat = rhss isa SparseMatrixCSC ? rhss.nzval : rhss - spawnvectors = vcat( - [build_expr(:vect, [conv(rhs) for rhs ∈ rhss_flat[((i-1)*lens+1):i*lens]]) for i in 1:numworks-1], - build_expr(:vect, [conv(rhs) for rhs ∈ rhss_flat[((numworks-1)*lens+1):end]])) - - spawn_exprs = [quote - $(spawnvars[i]) = ModelingToolkit.Distributed.remotecall($(i+1)) do - $(spawnvectors[i]) - end - end for i in 1:numworks] - spawn_exprs = ModelingToolkit.build_expr(:block, spawn_exprs) - resunpack_exprs = [:($(Symbol(reducevars[iter])) = fetch($(spawnvars[iter]))) for iter in 1:numworks] - - ip_let_expr.args[2] = quote - @sync begin - $spawn_exprs - $(resunpack_exprs...) - $(ip_let_expr.args[2]) - end - end - elseif parallel isa DaggerForm - @assert HAS_DAGGER[] "Dagger.jl is not loaded; please do `using Dagger`" - dagwrap(x) = x - dagwrap(ex::Expr) = dagwrap(ex, Val(ex.head)) - dagwrap(ex::Expr, ::Val) = ex - dagwrap(ex::Expr, ::Val{:call}) = :(Dagger.delayed($(ex.args[1]))($(dagwrap.(ex.args[2:end])...))) - new_rhss = dagwrap.(conv.(rhss)) - delayed_exprs = build_expr(:block, [:($(Symbol(computevars[i])) = Dagger.delayed(identity)($(new_rhss[i]))) for i in axes(computevars,1)]) - # TODO: treereduce? - reduce_expr = quote - $(Symbol(reducevar)) = collect(Dagger.delayed(vcat)($(computevars...))) - end - ip_let_expr.args[2] = quote - @sync begin - $delayed_exprs - $reduce_expr - $(ip_let_expr.args[2]) - end - end - end + dargs = map(destructure_arg, args) + i = findfirst(x->x isa DestructuredArgs, dargs) + similarto = i === nothing ? Array : dargs[i].name + array_expr = _make_array(rhss, similarto) + oop_expr = Func(dargs, [], array_expr) + out = Sym{Any}(gensym("out")) if rhss isa SparseMatrixCSC - rhss′ = map(conv∘process, rhss.nzval) - else - rhss′ = [conv(process(r)) for r in rhss] - end - - tuple_sys_expr = build_expr(:tuple, rhss′) - - if rhss isa Matrix - arr_sys_expr = build_expr(:vcat, [build_expr(:row,[conv(rhs) for rhs ∈ rhss[i,:]]) for i in 1:size(rhss,1)]) - elseif typeof(rhss) <: Array && !(typeof(rhss) <: Vector) - vector_form = build_expr(:vect, [conv(rhs) for rhs ∈ rhss]) - arr_sys_expr = :(reshape($vector_form,$(size(rhss)...))) - elseif rhss isa SparseMatrixCSC - vector_form = build_expr(:vect, [conv(rhs) for rhs ∈ nonzeros(rhss)]) - arr_sys_expr = :(SparseMatrixCSC{eltype($(first(argnames))),Int}($(size(rhss)...), $(rhss.colptr), $(rhss.rowval), $vector_form)) - else # Vector - arr_sys_expr = build_expr(:vect, [conv(rhs) for rhs ∈ rhss]) - end - - xname = gensym(:MTK) - - arr_sys_expr = (typeof(rhss) <: Vector || typeof(rhss) <: Matrix) && !(eltype(rhss) <: AbstractArray) ? quote - if $force_SA || typeof($(fargs.args[1])) <: Union{ModelingToolkit.StaticArrays.SArray,ModelingToolkit.LabelledArrays.SLArray} - $xname = ModelingToolkit.StaticArrays.@SArray $arr_sys_expr - if $convert_oop && !(typeof($(fargs.args[1])) <: Number) && $(typeof(rhss) <: Vector) # Only try converting if it should match `u` - return similar_type($(fargs.args[1]),eltype($xname))($xname) - else - return $xname - end - else - $xname = $arr_sys_expr - if $convert_oop && $(typeof(rhss) <: Vector) - if !(typeof($(fargs.args[1])) <: Array) && !(typeof($(fargs.args[1])) <: Number) && eltype($(fargs.args[1])) <: eltype($xname) - # Last condition: avoid known error because this doesn't change eltypes! - return convert(typeof($(fargs.args[1])),$xname) - elseif typeof($(fargs.args[1])) <: ModelingToolkit.LabelledArrays.LArray - # LArray just needs to add the names back! - return ModelingToolkit.LabelledArrays.LArray{ModelingToolkit.LabelledArrays.symnames(typeof($(fargs.args[1])))}($xname) - else - return $xname - end - else - return $xname - end - end - end : arr_sys_expr - - if inner_let !== nothing - inner_let_expr = inner_let(conv ∘ process) - arr_sys_expr = inner_let_expr(arr_sys_expr) - ip_let_expr.args[2] = inner_let_expr(ip_let_expr.args[2]) + I,J, _ = findnz(rhss) + outputidxs = CartesianIndex.(I, J) + elseif rhss isa SparseVector + I,_ = findnz(rhss) + outputidxs = I + elseif isnothing(outputidxs) + outputidxs = collect(eachindex(rhss)) end - if fillzeros && outputidxs === nothing - ip_let_expr = quote - $fill_array_with_zero!($X) - $ip_let_expr - end - end - - arr_let_expr = Expr(:let, var_eqs, arr_sys_expr) - - oop_bounds_block = checkbounds ? arr_let_expr : :(@inbounds $arr_let_expr) - ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds $ip_let_expr) - - oop_ex = headerfun(oop_bounds_block, fargs, false) - iip_ex = headerfun(ip_bounds_block, fargs, true; X=X) - - if !linenumbers - oop_ex = striplines(oop_ex) - iip_ex = striplines(iip_ex) + if skipzeros + ii = findall(i->!_iszero(rhss[i]), outputidxs) + array = AtIndex.(outputidxs[ii], rhss[ii]) + else + array = AtIndex.(outputidxs, rhss) end + ip_expr = Func([out, dargs...], [], SetArray(false, out, array)) if expression == Val{true} - return ModelingToolkit.inject_registered_module_functions(oop_ex), ModelingToolkit.inject_registered_module_functions(iip_ex) + return toexpr(oop_expr), toexpr(ip_expr) else - return _build_and_inject_function(@__MODULE__, oop_ex), _build_and_inject_function(@__MODULE__, iip_ex) + return _build_and_inject_function(@__MODULE__, toexpr(oop_expr)), + _build_and_inject_function(@__MODULE__, toexpr(ip_expr)) end end +function _make_array(rhss::AbstractSparseArray, similarto) + MakeSparseArray(map(x->_make_array(x, similarto), rhss)) +end + +function _make_array(rhss::AbstractArray, similarto) + MakeArray(map(x->_make_array(x, similarto), rhss), similarto) +end + +_make_array(x, similarto) = x + function vars_to_pairs(name,vs::Union{Tuple, AbstractArray}, symsdict=Dict()) vs_names = tosymbol.(vs) for (v,k) in zip(vs_names, vs) From b0cf5f1a9cf0ff32eaaa9a661be66f19bf2c318c Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Mon, 1 Feb 2021 19:48:09 -0500 Subject: [PATCH 12/26] fixes --- src/build_function.jl | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index d52199238c..3cc04bddac 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -1,3 +1,5 @@ +using SymbolicUtils.Code + abstract type BuildTargets end struct JuliaTarget <: BuildTargets end struct StanTarget <: BuildTargets end @@ -48,23 +50,6 @@ function build_function(args...;target = JuliaTarget(),kwargs...) _build_function(target,args...;kwargs...) end -function addheader(ex, fargs, iip; X=gensym(:MTIIPVar)) - if iip - wrappedex = :( - ($X,$(fargs.args...)) -> begin - $ex - nothing - end - ) - else - wrappedex = :( - ($(fargs.args...),) -> begin - $ex - end - ) - end - wrappedex -end function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar)) integrator = gensym(:MTKIntegrator) @@ -112,12 +97,14 @@ end # Scalar output +destructure_arg(arg::Union{AbstractArray, Tuple}) = DestructuredArgs(map(value, arg)) +destructure_arg(arg) = value(arg) + function _build_function(target::JuliaTarget, op, args...; conv = toexpr, expression = Val{true}, checkbounds = false, - linenumbers = true, - headerfun=addheader) + linenumbers = true) dargs = map(destructure_arg, args) expr = toexpr(Func(dargs, [], value(op))) From e1fd36b129de2a7a4c9931f52059a1b2d4cfcf9f Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Mon, 1 Feb 2021 19:53:00 -0500 Subject: [PATCH 13/26] more fixes --- src/build_function.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 3cc04bddac..1dc19dfe43 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -220,10 +220,10 @@ Special Keyword Argumnets: out = Sym{Any}(gensym("out")) if rhss isa SparseMatrixCSC - I,J, _ = findnz(rhss) + I,J, rhss = findnz(rhss) outputidxs = CartesianIndex.(I, J) elseif rhss isa SparseVector - I,_ = findnz(rhss) + I, rhss = findnz(rhss) outputidxs = I elseif isnothing(outputidxs) outputidxs = collect(eachindex(rhss)) @@ -233,7 +233,7 @@ Special Keyword Argumnets: ii = findall(i->!_iszero(rhss[i]), outputidxs) array = AtIndex.(outputidxs[ii], rhss[ii]) else - array = AtIndex.(outputidxs, rhss) + array = AtIndex.(vec(outputidxs), vec(rhss)) end ip_expr = Func([out, dargs...], [], SetArray(false, out, array)) From 82db1574e5160efccf95c0437629ce1ab744444b Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Mon, 1 Feb 2021 20:53:55 -0500 Subject: [PATCH 14/26] delete some substitution code Co-authored-by: "Yingbo Ma" --- src/systems/diffeqs/abstractodesystem.jl | 47 ++++++++---------------- 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 0d45982a28..1036883a93 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -35,57 +35,42 @@ function calculate_jacobian(sys::AbstractODESystem; return jac end -struct ODEToExpr - sys::AbstractODESystem - states::Vector -end -ODEToExpr(@nospecialize(sys)) = ODEToExpr(sys,states(sys)) -(f::ODEToExpr)(O::Num) = f(value(O)) -function (f::ODEToExpr)(O::Term) - if isa(operation(O), Sym) - any(isequal(O), f.states) && return tosymbol(O) - # dependent variables - return build_expr(:call, Any[operation(O).name; f.(arguments(O))]) - end - return build_expr(:call, Any[operation(O); f.(arguments(O))]) -end -(f::ODEToExpr)(x) = toexpr(x) - function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); simplify=false, kwargs...) tgrad = calculate_tgrad(sys,simplify=simplify) - return build_function(tgrad, dvs, ps, sys.iv; - conv = ODEToExpr(sys), kwargs...) + return build_function(tgrad, dvs, ps, sys.iv; kwargs...) end function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); simplify=false, sparse = false, kwargs...) jac = calculate_jacobian(sys;simplify=simplify,sparse=sparse) - sub = Dict(value.(dvs) .=> makesym.(value.(dvs))) - jac = map(d->substitute(d, sub), jac) - return build_function(jac, dvs, ps, sys.iv; - conv = ODEToExpr(sys), kwargs...) + return build_function(jac, dvs, ps, sys.iv; kwargs...) end function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); kwargs...) # optimization #obsvars = map(eq->eq.lhs, observed(sys)) #fulldvs = [dvs; obsvars] - fulldvs = dvs - fulldvs′ = makesym.(value.(fulldvs)) - sub = Dict(fulldvs .=> fulldvs′) # substitute x(t) by just x - rhss = [substitute(deq.rhs, sub) for deq ∈ equations(sys)] + rhss = [deq.rhs for deq ∈ equations(sys)] #obss = [makesym(value(eq.lhs)) ~ substitute(eq.rhs, sub) for eq ∈ observed(sys)] #rhss = Let(obss, rhss) - dvs′ = fulldvs′[1:length(dvs)] - ps′ = makesym.(value.(ps), states=()) - # TODO: add an optional check on the ordering of observed equations - return build_function(rhss, dvs′, ps′, sys.iv; - conv = ODEToExpr(sys),kwargs...) + return build_function(rhss, + map(x->uncall_delayed_var(value(x), sys), dvs), + map(x->uncall_delayed_var(value(x), sys), ps), + sys.iv; kwargs...) +end + +function uncall_delayed_var(x, sys) + if istree(x) && + operation(x) isa Sym && + !(length(arguments(x)) == 1 && isequal(arguments(x)[1], sys.iv)) + return operation(x) + end + return x end function calculate_massmatrix(sys::AbstractODESystem; simplify=false) From 27f423a166523ecb38b3dd43bc2e46bd04afc610 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Mon, 1 Feb 2021 20:59:06 -0500 Subject: [PATCH 15/26] naming and commenting Co-authored-by: "Yingbo Ma" --- src/systems/diffeqs/abstractodesystem.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 1036883a93..4edc922eea 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -59,15 +59,20 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param # TODO: add an optional check on the ordering of observed equations return build_function(rhss, - map(x->uncall_delayed_var(value(x), sys), dvs), - map(x->uncall_delayed_var(value(x), sys), ps), + map(x->time_varying_as_func(value(x), sys), dvs), + map(x->time_varying_as_func(value(x), sys), ps), sys.iv; kwargs...) end -function uncall_delayed_var(x, sys) +function time_varying_as_func(x, sys) + # if something is not x(t) (the current state) + # but is `x(t-1)` or something like that, pass in `x` as a callable function rather + # than pass in a value in place of x(t). + # + # This is done by just making `x` the argument of the function. if istree(x) && operation(x) isa Sym && - !(length(arguments(x)) == 1 && isequal(arguments(x)[1], sys.iv)) + !(length(arguments(x)) == 1 && isequal(arguments(x)[1], independent_variable(sys))) return operation(x) end return x From cd95dcacc29915f0882aa6952c6201e748d2a58c Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Mon, 1 Feb 2021 21:21:31 -0500 Subject: [PATCH 16/26] get more tests to pass Co-authored-by: "Yingbo Ma" --- src/systems/diffeqs/sdesystem.jl | 8 +++++--- test/labelledarrays.jl | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 7158483fd6..71dece5561 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -98,9 +98,11 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps; SDESystem(deqs, neqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, default_u0, default_p) end -function generate_diffusion_function(sys::SDESystem, dvs = sys.states, ps = sys.ps; kwargs...) - return build_function(sys.noiseeqs, dvs, ps, sys.iv; - conv = ODEToExpr(sys),kwargs...) +function generate_diffusion_function(sys::SDESystem, dvs = states(sys), ps = parameters(sys); kwargs...) + return build_function(sys.noiseeqs, + map(x->time_varying_as_func(value(x), sys), dvs), + map(x->time_varying_as_func(value(x), sys), ps), + sys.iv; kwargs...) end """ diff --git a/test/labelledarrays.jl b/test/labelledarrays.jl index 67ed915ffe..5122cb6183 100644 --- a/test/labelledarrays.jl +++ b/test/labelledarrays.jl @@ -38,4 +38,5 @@ d = LVector(x=1.0,y=2.0,z=3.0) @test ff.jac(b,p,ForwardDiff.Dual(0.0,1.0)) isa SArray @test eltype(ff.jac(b,p,ForwardDiff.Dual(0.0,1.0))) <: ForwardDiff.Dual @test ff.jac(d,p,ForwardDiff.Dual(0.0,1.0)) isa Array +@inferred ff.jac(d,p,ForwardDiff.Dual(0.0,1.0)) @test eltype(ff.jac(d,p,ForwardDiff.Dual(0.0,1.0))) <: ForwardDiff.Dual From 4751a2d5b70e5d6a4aa07136b6624ee6911a4ebe Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Tue, 2 Feb 2021 10:08:45 -0500 Subject: [PATCH 17/26] the alternative to adding headers --- src/build_function.jl | 56 +++++++++++++-------------------- src/systems/jumps/jumpsystem.jl | 15 ++++++--- test/reactionsystem.jl | 6 ++-- 3 files changed, 34 insertions(+), 43 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 1dc19dfe43..8d1f76ca44 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -50,28 +50,6 @@ function build_function(args...;target = JuliaTarget(),kwargs...) _build_function(target,args...;kwargs...) end - -function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar)) - integrator = gensym(:MTKIntegrator) - if iip - wrappedex = :( - $integrator -> begin - ($X,$(fargs.args...)) = (($integrator).u,($integrator).u,($integrator).p,($integrator).t) - $ex - nothing - end - ) - else - wrappedex = :( - $integrator -> begin - ($(fargs.args...),) = (($integrator).u,($integrator).p,($integrator).t) - $ex - end - ) - end - wrappedex -end - function unflatten_long_ops(op, N=4) rule1 = @rule((+)((~~x)) => length(~~x) > N ? +(+((~~x)[1:N]...) + (+)((~~x)[N+1:end]...)) : nothing) @@ -106,7 +84,7 @@ function _build_function(target::JuliaTarget, op, args...; checkbounds = false, linenumbers = true) - dargs = map(destructure_arg, args) + dargs = map(destructure_arg, [args...]) expr = toexpr(Func(dargs, [], value(op))) if expression == Val{true} @@ -202,21 +180,25 @@ Special Keyword Argumnets: filling function is 0. - `fillzeros`: Whether to perform `fill(out,0)` before the calculations to ensure safety with `skipzeros`. - """ - function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; - conv = toexpr, expression = Val{true}, - checkbounds = false, - linenumbers = false, multithread=nothing, - outputidxs=nothing, - skipzeros = false, - fillzeros = skipzeros && !(typeof(rhss)<:SparseMatrixCSC), - parallel=SerialForm(), kwargs...) - - dargs = map(destructure_arg, args) +""" +function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; + conv = toexpr, expression = Val{true}, + checkbounds = false, + linenumbers = false, multithread=nothing, + outputidxs=nothing, + skipzeros = false, + wrap_code = (nothing, nothing), + fillzeros = skipzeros && !(typeof(rhss)<:SparseMatrixCSC), + parallel=SerialForm(), kwargs...) + + dargs = map(destructure_arg, [args...]) i = findfirst(x->x isa DestructuredArgs, dargs) similarto = i === nothing ? Array : dargs[i].name array_expr = _make_array(rhss, similarto) oop_expr = Func(dargs, [], array_expr) + if !isnothing(wrap_code[1]) + oop_expr = wrap_code[1](oop_expr) + end out = Sym{Any}(gensym("out")) if rhss isa SparseMatrixCSC @@ -233,9 +215,13 @@ Special Keyword Argumnets: ii = findall(i->!_iszero(rhss[i]), outputidxs) array = AtIndex.(outputidxs[ii], rhss[ii]) else - array = AtIndex.(vec(outputidxs), vec(rhss)) + # sometimes outputidxs is a Tuple + array = AtIndex.(vec(collect(outputidxs)), vec(rhss)) end ip_expr = Func([out, dargs...], [], SetArray(false, out, array)) + if !isnothing(wrap_code[2]) + ip_expr = wrap_code[2](ip_expr) + end if expression == Val{true} return toexpr(oop_expr), toexpr(ip_expr) diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index c183813dc5..c84715afa8 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -80,26 +80,31 @@ function JumpSystem(eqs, iv, states, ps; end function generate_rate_function(js, rate) - build_function(rate, states(js), parameters(js), + rf = build_function(rate, states(js), parameters(js), independent_variable(js), conv = states_to_sym(states(js)), expression=Val{true}) end +function add_integrator_header() + integrator = gensym(:MTKIntegrator) + + expr -> Func([DestructuredArgs(expr.args, integrator, inds=[:u, :p, :t])], [], expr.body), + expr -> Func([DestructuredArgs(expr.args, integrator, inds=[:u, :u, :p, :t])], [], expr.body) +end function generate_affect_function(js, affect, outputidxs) bf = build_function(map(x->x isa Equation ? x.rhs : x , affect), states(js), parameters(js), - conv = states_to_sym(states(js)), independent_variable(js), expression=Val{true}, - headerfun=add_integrator_header, + wrap_code=add_integrator_header(), outputidxs=outputidxs)[2] end function assemble_vrj(js, vrj, statetoid) rate = @RuntimeGeneratedFunction(generate_rate_function(js, vrj.rate)) outputvars = (value(affect.lhs) for affect in vrj.affect!) - outputidxs = ((statetoid[var] for var in outputvars)...,) + outputidxs = [statetoid[var] for var in outputvars] affect = @RuntimeGeneratedFunction(generate_affect_function(js, vrj.affect!, outputidxs)) VariableRateJump(rate, affect) end @@ -119,7 +124,7 @@ end function assemble_crj(js, crj, statetoid) rate = @RuntimeGeneratedFunction(generate_rate_function(js, crj.rate)) outputvars = (value(affect.lhs) for affect in crj.affect!) - outputidxs = ((statetoid[var] for var in outputvars)...,) + outputidxs = [statetoid[var] for var in outputvars] affect = @RuntimeGeneratedFunction(generate_affect_function(js, crj.affect!, outputidxs)) ConstantRateJump(rate, affect) end diff --git a/test/reactionsystem.jl b/test/reactionsystem.jl index f146b6b7e0..6ba1efbd29 100644 --- a/test/reactionsystem.jl +++ b/test/reactionsystem.jl @@ -130,7 +130,7 @@ vidxs = 19:20 @test all(map(i -> typeof(js.eqs[i]) <: DiffEqJump.ConstantRateJump, cidxs)) @test all(map(i -> typeof(js.eqs[i]) <: DiffEqJump.VariableRateJump, vidxs)) -pars = rand(length(k)); u0 = rand(1:10,4); time = rand(); +pars = rand(length(k)); u0 = rand(1:10,4); ttt = rand(); jumps = Vector{Union{ConstantRateJump, MassActionJump, VariableRateJump}}(undef,length(rxs)) jumps[1] = MassActionJump(pars[1], Vector{Pair{Int,Int}}(), [1 => 1]); @@ -166,7 +166,7 @@ for i in midxs end for i in cidxs crj = MT.assemble_crj(js, js.eqs[i], statetoid) - @test isapprox(crj.rate(u0,p,time), jumps[i].rate(u0,p,time)) + @test isapprox(crj.rate(u0,p,ttt), jumps[i].rate(u0,p,ttt)) fake_integrator1 = (u=zeros(4),p=p,t=0); fake_integrator2 = deepcopy(fake_integrator1); crj.affect!(fake_integrator1); jumps[i].affect!(fake_integrator2); @@ -174,7 +174,7 @@ for i in cidxs end for i in vidxs crj = MT.assemble_vrj(js, js.eqs[i], statetoid) - @test isapprox(crj.rate(u0,p,time), jumps[i].rate(u0,p,time)) + @test isapprox(crj.rate(u0,p,ttt), jumps[i].rate(u0,p,ttt)) fake_integrator1 = (u=zeros(4),p=p,t=0.); fake_integrator2 = deepcopy(fake_integrator1); crj.affect!(fake_integrator1); jumps[i].affect!(fake_integrator2); @test fake_integrator1 == fake_integrator2 From 7c7be57ad321b11f07fd32f0df213aa69aba7261 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Tue, 2 Feb 2021 11:02:41 -0500 Subject: [PATCH 18/26] handle some bs behavior of map on Reshaped Sparse array --- src/build_function.jl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 8d1f76ca44..9046c18757 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -76,7 +76,7 @@ end # Scalar output destructure_arg(arg::Union{AbstractArray, Tuple}) = DestructuredArgs(map(value, arg)) -destructure_arg(arg) = value(arg) +destructure_arg(arg) = arg function _build_function(target::JuliaTarget, op, args...; conv = toexpr, @@ -85,7 +85,7 @@ function _build_function(target::JuliaTarget, op, args...; linenumbers = true) dargs = map(destructure_arg, [args...]) - expr = toexpr(Func(dargs, [], value(op))) + expr = toexpr(Func(dargs, [], op)) if expression == Val{true} expr @@ -200,6 +200,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; oop_expr = wrap_code[1](oop_expr) end + ## In-place version out = Sym{Any}(gensym("out")) if rhss isa SparseMatrixCSC I,J, rhss = findnz(rhss) @@ -232,11 +233,22 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; end function _make_array(rhss::AbstractSparseArray, similarto) - MakeSparseArray(map(x->_make_array(x, similarto), rhss)) + arr = map(x->_make_array(x, similarto), rhss) + if !(arr isa AbstractSparseArray) + _make_array(arr, similarto) + else + MakeSparseArray(arr) + end end function _make_array(rhss::AbstractArray, similarto) - MakeArray(map(x->_make_array(x, similarto), rhss), similarto) + arr = map(x->_make_array(x, similarto), rhss) + # Ugh reshaped array of a sparse array when mapped gives a sparse array + if arr isa AbstractSparseArray + _make_array(arr, similarto) + else + MakeArray(arr, similarto) + end end _make_array(x, similarto) = x From c1967e8b676474d4a06899f15fd3e0501bc61c05 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Thu, 4 Feb 2021 08:37:08 -0500 Subject: [PATCH 19/26] forward module in build_function --- src/build_function.jl | 20 +++++++++------- test/runtests.jl | 56 +++++++++++++++++++++---------------------- 2 files changed, 40 insertions(+), 36 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 9046c18757..b6f0b0f3c6 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -81,6 +81,7 @@ destructure_arg(arg) = arg function _build_function(target::JuliaTarget, op, args...; conv = toexpr, expression = Val{true}, + expression_module = @__MODULE__(), checkbounds = false, linenumbers = true) @@ -90,12 +91,12 @@ function _build_function(target::JuliaTarget, op, args...; if expression == Val{true} expr else - _build_and_inject_function(@__MODULE__, expr) + _build_and_inject_function(expression_module, expr) end end function _build_and_inject_function(mod::Module, ex) - @RuntimeGeneratedFunction(ex) + @RuntimeGeneratedFunction(mod, ex) end # Detect heterogeneous element types of "arrays of matrices/sparce matrices" @@ -182,7 +183,8 @@ Special Keyword Argumnets: safety with `skipzeros`. """ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; - conv = toexpr, expression = Val{true}, + expression = Val{true}, + expression_module = @__MODULE__(), checkbounds = false, linenumbers = false, multithread=nothing, outputidxs=nothing, @@ -213,11 +215,13 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; end if skipzeros - ii = findall(i->!_iszero(rhss[i]), outputidxs) - array = AtIndex.(outputidxs[ii], rhss[ii]) + ii = findall(i->!iszero(rhss[i]), outputidxs) + array = AtIndex.(outputidxs[ii], + map(x->_make_array(x, similarto), rhss[ii])) else # sometimes outputidxs is a Tuple - array = AtIndex.(vec(collect(outputidxs)), vec(rhss)) + array = AtIndex.(vec(collect(outputidxs)), + map(x->_make_array(x, similarto), vec(rhss))) end ip_expr = Func([out, dargs...], [], SetArray(false, out, array)) if !isnothing(wrap_code[2]) @@ -227,8 +231,8 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; if expression == Val{true} return toexpr(oop_expr), toexpr(ip_expr) else - return _build_and_inject_function(@__MODULE__, toexpr(oop_expr)), - _build_and_inject_function(@__MODULE__, toexpr(ip_expr)) + return _build_and_inject_function(expression_module, toexpr(oop_expr)), + _build_and_inject_function(expression_module, toexpr(ip_expr)) end end diff --git a/test/runtests.jl b/test/runtests.jl index 8be142e333..45ebea3490 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,33 +1,33 @@ using SafeTestsets, Test -@safetestset "Parsing Test" begin include("variable_parsing.jl") end -@safetestset "Differentiation Test" begin include("derivatives.jl") end -@safetestset "Simplify Test" begin include("simplify.jl") end -@safetestset "Operation Overloads Test" begin include("operation_overloads.jl") end -@safetestset "Direct Usage Test" begin include("direct.jl") end -@safetestset "System Linearity Test" begin include("linearity.jl") end -@safetestset "Build Function Test" begin include("build_function.jl") end -@safetestset "ODESystem Test" begin include("odesystem.jl") end -@safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end -@safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end -@safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end -@safetestset "SDESystem Test" begin include("sdesystem.jl") end -@safetestset "NonlinearSystem Test" begin include("nonlinearsystem.jl") end -@safetestset "OptimizationSystem Test" begin include("optimizationsystem.jl") end -@safetestset "ReactionSystem Test" begin include("reactionsystem.jl") end -@safetestset "JumpSystem Test" begin include("jumpsystem.jl") end -@safetestset "ControlSystem Test" begin include("controlsystem.jl") end -@safetestset "Build Targets Test" begin include("build_targets.jl") end -@safetestset "Domain Test" begin include("domains.jl") end -@safetestset "Modelingtoolkitize Test" begin include("modelingtoolkitize.jl") end -@safetestset "Constraints Test" begin include("constraints.jl") end -@safetestset "Reduction Test" begin include("reduction.jl") end -@safetestset "Components Test" begin include("components.jl") end -@safetestset "PDE Construction Test" begin include("pde.jl") end -@safetestset "Lowering Integration Test" begin include("lowering_solving.jl") end -@safetestset "Test Big System Usage" begin include("bigsystem.jl") end -@safetestset "Depdendency Graph Test" begin include("dep_graphs.jl") end -@safetestset "Function Registration Test" begin include("function_registration.jl") end +# @safetestset "Parsing Test" begin include("variable_parsing.jl") end +# @safetestset "Differentiation Test" begin include("derivatives.jl") end +# @safetestset "Simplify Test" begin include("simplify.jl") end +# @safetestset "Operation Overloads Test" begin include("operation_overloads.jl") end +# @safetestset "Direct Usage Test" begin include("direct.jl") end +# @safetestset "System Linearity Test" begin include("linearity.jl") end +# @safetestset "Build Function Test" begin include("build_function.jl") end +# @safetestset "ODESystem Test" begin include("odesystem.jl") end +# @safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end +# @safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end +# @safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end +# @safetestset "SDESystem Test" begin include("sdesystem.jl") end +# @safetestset "NonlinearSystem Test" begin include("nonlinearsystem.jl") end +# @safetestset "OptimizationSystem Test" begin include("optimizationsystem.jl") end +# @safetestset "ReactionSystem Test" begin include("reactionsystem.jl") end +# @safetestset "JumpSystem Test" begin include("jumpsystem.jl") end +# @safetestset "ControlSystem Test" begin include("controlsystem.jl") end +# @safetestset "Build Targets Test" begin include("build_targets.jl") end +# @safetestset "Domain Test" begin include("domains.jl") end +# @safetestset "Modelingtoolkitize Test" begin include("modelingtoolkitize.jl") end +# @safetestset "Constraints Test" begin include("constraints.jl") end +# @safetestset "Reduction Test" begin include("reduction.jl") end +# @safetestset "Components Test" begin include("components.jl") end +# @safetestset "PDE Construction Test" begin include("pde.jl") end +# @safetestset "Lowering Integration Test" begin include("lowering_solving.jl") end +# @safetestset "Test Big System Usage" begin include("bigsystem.jl") end +# @safetestset "Depdendency Graph Test" begin include("dep_graphs.jl") end +# @safetestset "Function Registration Test" begin include("function_registration.jl") end @safetestset "Array of Array Test" begin include("build_function_arrayofarray.jl") end @testset "Distributed Test" begin include("distributed.jl") end @safetestset "Variable Utils Test" begin include("variable_utils.jl") end From 87568a9968eaad7c943eabb379567be50a1e7dda Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Thu, 4 Feb 2021 18:18:25 -0500 Subject: [PATCH 20/26] try qualifying the function --- src/build_function.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/build_function.jl b/src/build_function.jl index b6f0b0f3c6..367e739ab7 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -96,6 +96,11 @@ function _build_function(target::JuliaTarget, op, args...; end function _build_and_inject_function(mod::Module, ex) + if ex.head == :function && ex.args[1].head == :tuple + ex.args[1] = Expr(:call, :($mod.$(gensym())), ex.args[1].args...) + elseif ex.head == :(->) + return _build_and_inject_function(mod, Expr(:function, ex.args...)) + end @RuntimeGeneratedFunction(mod, ex) end From b25faaca3127390f5ead40b40bdd2c0bc6fff4be Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Thu, 4 Feb 2021 19:04:45 -0500 Subject: [PATCH 21/26] better recursive _set_array, assume output is already of the right sparsity --- src/build_function.jl | 50 +++++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 367e739ab7..b19b9ca043 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -201,34 +201,14 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; dargs = map(destructure_arg, [args...]) i = findfirst(x->x isa DestructuredArgs, dargs) similarto = i === nothing ? Array : dargs[i].name - array_expr = _make_array(rhss, similarto) - oop_expr = Func(dargs, [], array_expr) + oop_expr = Func(dargs, [], _make_array(rhss, similarto)) if !isnothing(wrap_code[1]) oop_expr = wrap_code[1](oop_expr) end - ## In-place version out = Sym{Any}(gensym("out")) - if rhss isa SparseMatrixCSC - I,J, rhss = findnz(rhss) - outputidxs = CartesianIndex.(I, J) - elseif rhss isa SparseVector - I, rhss = findnz(rhss) - outputidxs = I - elseif isnothing(outputidxs) - outputidxs = collect(eachindex(rhss)) - end + ip_expr = Func([out, dargs...], [], _set_array(out, outputidxs, rhss, skipzeros)) - if skipzeros - ii = findall(i->!iszero(rhss[i]), outputidxs) - array = AtIndex.(outputidxs[ii], - map(x->_make_array(x, similarto), rhss[ii])) - else - # sometimes outputidxs is a Tuple - array = AtIndex.(vec(collect(outputidxs)), - map(x->_make_array(x, similarto), vec(rhss))) - end - ip_expr = Func([out, dargs...], [], SetArray(false, out, array)) if !isnothing(wrap_code[2]) ip_expr = wrap_code[2](ip_expr) end @@ -237,7 +217,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; return toexpr(oop_expr), toexpr(ip_expr) else return _build_and_inject_function(expression_module, toexpr(oop_expr)), - _build_and_inject_function(expression_module, toexpr(ip_expr)) + _build_and_inject_function(expression_module, toexpr(ip_expr)) end end @@ -250,6 +230,30 @@ function _make_array(rhss::AbstractSparseArray, similarto) end end +## In-place version +function _set_array(out, outputidxs, rhss::AbstractArray, skipzeros) + if rhss isa Union{SparseVector, SparseMatrixCSC} + return SetArray(false, LiteralExpr(:($out.nzval)), rhss.nzval) + elseif isnothing(outputidxs) + outputidxs = collect(eachindex(rhss)) + end + + # sometimes outputidxs is a Tuple + ii = findall(i->!(rhss[i] isa AbstractArray) && !(skipzeros && iszero(rhss[i])), outputidxs) + jj = findall(i->rhss[i] isa AbstractArray, outputidxs) + exprs = [] + push!(exprs, SetArray(false, out, AtIndex.(vec(collect(outputidxs[ii])), vec(rhss[ii])))) + for j in jj + push!(exprs, _set_array(LiteralExpr(:($out[$j])), nothing, rhss[j], skipzeros)) + end + LiteralExpr(quote + $(exprs...) + end) +end + +_set_array(out, outputidxs, rhs, skipzeros) = rhs + + function _make_array(rhss::AbstractArray, similarto) arr = map(x->_make_array(x, similarto), rhss) # Ugh reshaped array of a sparse array when mapped gives a sparse array From 131d1da0cb26c4965a1ef6b6f464a3741bdba40b Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 4 Feb 2021 16:56:59 -0500 Subject: [PATCH 22/26] Pass in checkbounds and fix outputidxs --- src/build_function.jl | 18 +++++++------- test/runtests.jl | 56 +++++++++++++++++++++---------------------- 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index b19b9ca043..8d157f200e 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -207,7 +207,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; end out = Sym{Any}(gensym("out")) - ip_expr = Func([out, dargs...], [], _set_array(out, outputidxs, rhss, skipzeros)) + ip_expr = Func([out, dargs...], [], _set_array(out, outputidxs, rhss, checkbounds, skipzeros)) if !isnothing(wrap_code[2]) ip_expr = wrap_code[2](ip_expr) @@ -231,27 +231,27 @@ function _make_array(rhss::AbstractSparseArray, similarto) end ## In-place version -function _set_array(out, outputidxs, rhss::AbstractArray, skipzeros) +function _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros) if rhss isa Union{SparseVector, SparseMatrixCSC} - return SetArray(false, LiteralExpr(:($out.nzval)), rhss.nzval) - elseif isnothing(outputidxs) + return SetArray(checkbounds, LiteralExpr(:($out.nzval)), rhss.nzval) + elseif outputidxs === nothing outputidxs = collect(eachindex(rhss)) end # sometimes outputidxs is a Tuple - ii = findall(i->!(rhss[i] isa AbstractArray) && !(skipzeros && iszero(rhss[i])), outputidxs) - jj = findall(i->rhss[i] isa AbstractArray, outputidxs) + ii = findall(i->!(rhss[i] isa AbstractArray) && !(skipzeros && _iszero(rhss[i])), eachindex(outputidxs)) + jj = findall(i->rhss[i] isa AbstractArray, eachindex(outputidxs)) exprs = [] - push!(exprs, SetArray(false, out, AtIndex.(vec(collect(outputidxs[ii])), vec(rhss[ii])))) + push!(exprs, SetArray(checkbounds, out, AtIndex.(vec(collect(outputidxs[ii])), vec(rhss[ii])))) for j in jj - push!(exprs, _set_array(LiteralExpr(:($out[$j])), nothing, rhss[j], skipzeros)) + push!(exprs, _set_array(LiteralExpr(:($out[$j])), nothing, rhss[j], checkbounds, skipzeros)) end LiteralExpr(quote $(exprs...) end) end -_set_array(out, outputidxs, rhs, skipzeros) = rhs +_set_array(out, outputidxs, rhs, checkbounds, skipzeros) = rhs function _make_array(rhss::AbstractArray, similarto) diff --git a/test/runtests.jl b/test/runtests.jl index 45ebea3490..8be142e333 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,33 +1,33 @@ using SafeTestsets, Test -# @safetestset "Parsing Test" begin include("variable_parsing.jl") end -# @safetestset "Differentiation Test" begin include("derivatives.jl") end -# @safetestset "Simplify Test" begin include("simplify.jl") end -# @safetestset "Operation Overloads Test" begin include("operation_overloads.jl") end -# @safetestset "Direct Usage Test" begin include("direct.jl") end -# @safetestset "System Linearity Test" begin include("linearity.jl") end -# @safetestset "Build Function Test" begin include("build_function.jl") end -# @safetestset "ODESystem Test" begin include("odesystem.jl") end -# @safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end -# @safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end -# @safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end -# @safetestset "SDESystem Test" begin include("sdesystem.jl") end -# @safetestset "NonlinearSystem Test" begin include("nonlinearsystem.jl") end -# @safetestset "OptimizationSystem Test" begin include("optimizationsystem.jl") end -# @safetestset "ReactionSystem Test" begin include("reactionsystem.jl") end -# @safetestset "JumpSystem Test" begin include("jumpsystem.jl") end -# @safetestset "ControlSystem Test" begin include("controlsystem.jl") end -# @safetestset "Build Targets Test" begin include("build_targets.jl") end -# @safetestset "Domain Test" begin include("domains.jl") end -# @safetestset "Modelingtoolkitize Test" begin include("modelingtoolkitize.jl") end -# @safetestset "Constraints Test" begin include("constraints.jl") end -# @safetestset "Reduction Test" begin include("reduction.jl") end -# @safetestset "Components Test" begin include("components.jl") end -# @safetestset "PDE Construction Test" begin include("pde.jl") end -# @safetestset "Lowering Integration Test" begin include("lowering_solving.jl") end -# @safetestset "Test Big System Usage" begin include("bigsystem.jl") end -# @safetestset "Depdendency Graph Test" begin include("dep_graphs.jl") end -# @safetestset "Function Registration Test" begin include("function_registration.jl") end +@safetestset "Parsing Test" begin include("variable_parsing.jl") end +@safetestset "Differentiation Test" begin include("derivatives.jl") end +@safetestset "Simplify Test" begin include("simplify.jl") end +@safetestset "Operation Overloads Test" begin include("operation_overloads.jl") end +@safetestset "Direct Usage Test" begin include("direct.jl") end +@safetestset "System Linearity Test" begin include("linearity.jl") end +@safetestset "Build Function Test" begin include("build_function.jl") end +@safetestset "ODESystem Test" begin include("odesystem.jl") end +@safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end +@safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end +@safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end +@safetestset "SDESystem Test" begin include("sdesystem.jl") end +@safetestset "NonlinearSystem Test" begin include("nonlinearsystem.jl") end +@safetestset "OptimizationSystem Test" begin include("optimizationsystem.jl") end +@safetestset "ReactionSystem Test" begin include("reactionsystem.jl") end +@safetestset "JumpSystem Test" begin include("jumpsystem.jl") end +@safetestset "ControlSystem Test" begin include("controlsystem.jl") end +@safetestset "Build Targets Test" begin include("build_targets.jl") end +@safetestset "Domain Test" begin include("domains.jl") end +@safetestset "Modelingtoolkitize Test" begin include("modelingtoolkitize.jl") end +@safetestset "Constraints Test" begin include("constraints.jl") end +@safetestset "Reduction Test" begin include("reduction.jl") end +@safetestset "Components Test" begin include("components.jl") end +@safetestset "PDE Construction Test" begin include("pde.jl") end +@safetestset "Lowering Integration Test" begin include("lowering_solving.jl") end +@safetestset "Test Big System Usage" begin include("bigsystem.jl") end +@safetestset "Depdendency Graph Test" begin include("dep_graphs.jl") end +@safetestset "Function Registration Test" begin include("function_registration.jl") end @safetestset "Array of Array Test" begin include("build_function_arrayofarray.jl") end @testset "Distributed Test" begin include("distributed.jl") end @safetestset "Variable Utils Test" begin include("variable_utils.jl") end From 41130c8805d92eaeff4b50efbdfdbebc2c2fe875 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 4 Feb 2021 17:27:13 -0500 Subject: [PATCH 23/26] Fix latexify --- src/direct.jl | 46 ++++++++++++++++++++++++++++++----------- src/latexify_recipes.jl | 4 ++-- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/src/direct.jl b/src/direct.jl index 5b19e11998..14d0ca2454 100644 --- a/src/direct.jl +++ b/src/direct.jl @@ -231,18 +231,36 @@ function sparsehessian(O, vars::AbstractVector; simplify=false) return H end -""" - canonicalexpr(O) -> (canonical::Bool, expr) +# `_toexpr` is only used for latexify +function _toexpr(O; canonicalize=true) + if canonicalize + canonical, O = canonicalexpr(O) + canonical && return O + else + !istree(O) && return O + end + + op = operation(O) + args = arguments(O) + if op isa Differential + ex = _toexpr(args[1]; canonicalize=canonicalize) + wrt = _toexpr(op.x; canonicalize=canonicalize) + return :(_derivative($ex, $wrt)) + elseif op isa Sym + isempty(args) && return nameof(op) + return Expr(:call, _toexpr(op; canonicalize=canonicalize), _toexpr(args; canonicalize=canonicalize)...) + end + return Expr(:call, op, _toexpr(args; canonicalize=canonicalize)...) +end +_toexpr(s::Sym; kw...) = nameof(s) -Canonicalize `O`. Return `canonical` if `expr` is valid code to generate. -""" function canonicalexpr(O) !istree(O) && return true, O op = operation(O) args = arguments(O) if op === (^) if length(args) == 2 && args[2] isa Number && args[2] < 0 - ex = toexpr(args[1]) + ex = _toexpr(args[1]) if args[2] == -1 expr = Expr(:call, inv, ex) else @@ -254,11 +272,15 @@ function canonicalexpr(O) return false, O end -function toexpr(eq::Equation; kw...) - Expr(:(=), toexpr(eq.lhs; kw...), toexpr(eq.rhs; kw...)) -end +for fun in [:toexpr, :_toexpr] + @eval begin + function $fun(eq::Equation; kw...) + Expr(:(=), $fun(eq.lhs; kw...), $fun(eq.rhs; kw...)) + end -toexpr(eqs::AbstractArray; kw...) = map(eq->toexpr(eq; kw...), eqs) -toexpr(x::Integer; kw...) = x -toexpr(x::AbstractFloat; kw...) = x -toexpr(x::Num; kw...) = toexpr(value(x); kw...) + $fun(eqs::AbstractArray; kw...) = map(eq->$fun(eq; kw...), eqs) + $fun(x::Integer; kw...) = x + $fun(x::AbstractFloat; kw...) = x + $fun(x::Num; kw...) = $fun(value(x); kw...) + end +end diff --git a/src/latexify_recipes.jl b/src/latexify_recipes.jl index 28d3322ff8..6c7c4997a9 100644 --- a/src/latexify_recipes.jl +++ b/src/latexify_recipes.jl @@ -11,13 +11,13 @@ prettify_expr(expr::Expr) = Expr(expr.head, prettify_expr.(expr.args)...) # that latexify can deal with rhs = getfield.(eqs, :rhs) - rhs = prettify_expr.(toexpr(rhs; canonicalize=false)) + rhs = prettify_expr.(_toexpr(rhs; canonicalize=false)) rhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in rhs] rhs = [postwalk(x -> x isa Expr && x.args[1] == :_derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in rhs] rhs = [postwalk(x -> x isa Expr && x.args[1] == :_derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in rhs] lhs = getfield.(eqs, :lhs) - lhs = prettify_expr.(toexpr(lhs; canonicalize=false)) + lhs = prettify_expr.(_toexpr(lhs; canonicalize=false)) lhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in lhs] lhs = [postwalk(x -> x isa Expr && x.args[1] == :_derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in lhs] lhs = [postwalk(x -> x isa Expr && x.args[1] == :_derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in lhs] From 096e762455f36892d25780d998bae167f4e3ab43 Mon Sep 17 00:00:00 2001 From: Dan Padilha Date: Fri, 5 Feb 2021 13:27:49 +0900 Subject: [PATCH 24/26] Add work-around for RGF to set the cache module tag appropriately. --- src/build_function.jl | 5 ++++- test/precompile_test.jl | 8 ++------ test/precompile_test/ODEPrecompileTest.jl | 15 ++++++--------- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 8d157f200e..1d6035755c 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -101,7 +101,10 @@ function _build_and_inject_function(mod::Module, ex) elseif ex.head == :(->) return _build_and_inject_function(mod, Expr(:function, ex.args...)) end - @RuntimeGeneratedFunction(mod, ex) + # XXX: Workaround to specify the module as both the cache module AND context module. + # Currently, the @RuntimeGeneratedFunction macro only sets the context module. + module_tag = getproperty(mod, RuntimeGeneratedFunctions._tagname) + RuntimeGeneratedFunctions.RuntimeGeneratedFunction(module_tag, module_tag, ex) end # Detect heterogeneous element types of "arrays of matrices/sparce matrices" diff --git a/test/precompile_test.jl b/test/precompile_test.jl index 61abebe55e..40a2c2452c 100644 --- a/test/precompile_test.jl +++ b/test/precompile_test.jl @@ -8,8 +8,7 @@ using ODEPrecompileTest u = collect(1:3) p = collect(4:6) -# This case does not work, because "f_bad" gets defined in ModelingToolkit -# instead of in the compiled module! +# These cases do not work, because they get defined in the ModelingToolkit's RGF cache. @test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_iip).parameters[2]) == ModelingToolkit @test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_oop).parameters[2]) == ModelingToolkit @test parentmodule(typeof(ODEPrecompileTest.f_noeval_bad.f.f_iip).parameters[2]) == ModelingToolkit @@ -17,10 +16,7 @@ p = collect(4:6) @test_throws KeyError ODEPrecompileTest.f_bad(u, p, 0.1) @test_throws KeyError ODEPrecompileTest.f_noeval_bad(u, p, 0.1) -# This case works, because "f_good" gets defined in the precompiled module. -@test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_iip).parameters[2]) == ODEPrecompileTest -@test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_oop).parameters[2]) == ODEPrecompileTest +# This case works, because it gets defined with the appropriate cache and context tags. @test parentmodule(typeof(ODEPrecompileTest.f_noeval_good.f.f_iip).parameters[2]) == ODEPrecompileTest @test parentmodule(typeof(ODEPrecompileTest.f_noeval_good.f.f_oop).parameters[2]) == ODEPrecompileTest -@test ODEPrecompileTest.f_good(u, p, 0.1) == [4, 0, -16] @test ODEPrecompileTest.f_noeval_good(u, p, 0.1) == [4, 0, -16] \ No newline at end of file diff --git a/test/precompile_test/ODEPrecompileTest.jl b/test/precompile_test/ODEPrecompileTest.jl index 453cb0d774..fc8d784d91 100644 --- a/test/precompile_test/ODEPrecompileTest.jl +++ b/test/precompile_test/ODEPrecompileTest.jl @@ -5,7 +5,7 @@ module ODEPrecompileTest # Define some variables @parameters t σ ρ β @variables x(t) y(t) z(t) - @derivatives D'~t + D = Differential(t) # Define a differential equation eqs = [D(x) ~ σ*(y-x), @@ -16,17 +16,14 @@ module ODEPrecompileTest return ODEFunction(de, [x,y,z], [σ,ρ,β]; kwargs...) end - # Build an ODEFunction as part of the module's precompilation. This case - # will not work, because the generated RGFs will be put into - # ModelingToolkit's RGF cache. + # Build an ODEFunction as part of the module's precompilation. These cases + # will not work, because the generated RGFs are put into the ModelingToolkit cache. const f_bad = system() + const f_noeval_bad = system(; eval_expression=false) - # This case will work, because it will be put into our own module's cache. + # Setting eval_expression=false and eval_module=[this module] will ensure + # the RGFs are put into our own cache, initialised below. using RuntimeGeneratedFunctions RuntimeGeneratedFunctions.init(@__MODULE__) - const f_good = system(; eval_module=@__MODULE__) - - # Also test that eval_expression=false works - const f_noeval_bad = system(; eval_expression=false) const f_noeval_good = system(; eval_expression=false, eval_module=@__MODULE__) end \ No newline at end of file From 5128b187c390d9acf8d3c00cfcbe284ba90705c9 Mon Sep 17 00:00:00 2001 From: Dan Padilha Date: Fri, 5 Feb 2021 13:55:19 +0900 Subject: [PATCH 25/26] Move precompiled modules test above the un-safe distributed test. --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 8be142e333..8d5bb1f18b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,9 +29,9 @@ using SafeTestsets, Test @safetestset "Depdendency Graph Test" begin include("dep_graphs.jl") end @safetestset "Function Registration Test" begin include("function_registration.jl") end @safetestset "Array of Array Test" begin include("build_function_arrayofarray.jl") end +@safetestset "Precompiled Modules Test" begin include("precompile_test.jl") end @testset "Distributed Test" begin include("distributed.jl") end @safetestset "Variable Utils Test" begin include("variable_utils.jl") end println("Last test requires gcc available in the path!") @safetestset "C Compilation Test" begin include("ccompile.jl") end @safetestset "Latexify recipes Test" begin include("latexify.jl") end -@safetestset "Precompiled Modules Test" begin include("precompile_test.jl") end From b0a67795a11f893aafe2714ef3e3098545604611 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Tue, 9 Feb 2021 18:23:46 +0530 Subject: [PATCH 26/26] set LOADPATH in @everywhere --- test/precompile_test.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/precompile_test.jl b/test/precompile_test.jl index 40a2c2452c..a09830c15a 100644 --- a/test/precompile_test.jl +++ b/test/precompile_test.jl @@ -1,8 +1,11 @@ using Test using ModelingToolkit +using Distributed + # Test that the precompiled ODE system works -push!(LOAD_PATH, joinpath(@__DIR__, "precompile_test")) +@everywhere push!(LOAD_PATH, joinpath(@__DIR__, "precompile_test")) + using ODEPrecompileTest u = collect(1:3)