Skip to content

Commit

Permalink
ChainRulesCore extension
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Mar 4, 2023
1 parent 6f36a05 commit 0835b5a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
10 changes: 9 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,15 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
SciMLBaseChainRulesCoreExt = "ChainRulesCore"

[compat]
ArrayInterface = "6, 7"
ChainRulesCore = "1.15"
CommonSolve = "0.2"
ConstructionBase = "1"
DocStringExtensions = "0.8, 0.9"
Expand All @@ -51,6 +58,7 @@ TruncatedStacktraces = "1"
julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -60,4 +68,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Pkg", "SafeTestsets", "Test", "StaticArrays"]
test = ["Pkg", "SafeTestsets", "Test", "StaticArrays", "ChainRulesCore"]
35 changes: 35 additions & 0 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
module SciMLBaseChainRulesCoreExt

using SciMLBase
isdefined(Base, :get_extension) ? (import ChainRulesCore) : (import ..ChainRulesCore)

function ChainRulesCore.rrule(::Type{
<:SciMLBase.PDETimeSeriesSolution{T, N, uType, Disc, Sol, DType, tType, domType, ivType, dvType,
P, A,
IType}}, u,
args...) where {T, N, uType, Disc, Sol, DType, tType, domType, ivType, dvType,
P, A,
IType}
function PDETimeSeriesSolutionAdjoint(ȳ)
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
end

SciMLBase.PDETimeSeriesSolution{T, N, uType, Disc, Sol, DType, tType, domType, ivType, dvType,
P, A,
IType}(u, args...), PDETimeSeriesSolutionAdjoint
end

function ChainRulesCore.rrule(::Type{
<:SciMLBase.PDENoTimeSolution{T, N, uType, Disc, Sol, domType, ivType, dvType, P, A,
IType}}, u,
args...) where {T, N, uType, Disc, Sol, domType, ivType, dvType, P, A,
IType}
function PDENoTimeSolutionAdjoint(ȳ)
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
end

SciMLBase.PDENoTimeSolution{T, N, uType, Disc, Sol, domType, ivType, dvType, P, A,
IType}(u, args...), PDENoTimeSolutionAdjoint
end

end
6 changes: 6 additions & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,12 @@ function wrapfun_oop end
function wrapfun_iip end
function unwrap_fw end

@static if !isdefined(Base, :get_extension)
function __init__()
@require ChainRulesCore="d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin include("../ext/SciMLBaseChainRulesCore.jl") end
end
end

export ReturnCode

export DEAlgorithm, SciMLAlgorithm, DEProblem, DEAlgorithm, DESolution, SciMLSolution
Expand Down

0 comments on commit 0835b5a

Please sign in to comment.