Skip to content

Commit

Permalink
refactor: deprecate uses of HamiltonianNN and NeuralHamiltonianDE
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 28, 2024
1 parent 88fa06d commit 7f101e3
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 133 deletions.
8 changes: 3 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Expand All @@ -21,7 +19,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1.5"
Expand All @@ -41,7 +38,6 @@ DistributionsAD = "0.6"
ExplicitImports = "1.9"
Flux = "0.14.15"
ForwardDiff = "0.10"
Functors = "0.4"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
LinearAlgebra = "1.10"
Expand Down Expand Up @@ -78,6 +74,7 @@ DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -97,6 +94,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BenchmarkTools", "ComponentArrays", "DelayDiffEq", "DiffEqCallbacks", "Distances", "Distributed", "ExplicitImports", "Flux", "Hwloc", "InteractiveUtils", "LuxCUDA", "MLDatasets", "NNlib", "OneHotArrays", "Optimisers", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEq", "Printf", "Random", "ReTestItems", "Reexport", "Statistics", "StochasticDiffEq", "Test"]
test = ["Aqua", "BenchmarkTools", "ComponentArrays", "DelayDiffEq", "DiffEqCallbacks", "Distances", "Distributed", "ExplicitImports", "ForwardDiff", "Flux", "Hwloc", "InteractiveUtils", "LuxCUDA", "MLDatasets", "NNlib", "OneHotArrays", "Optimisers", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEq", "Printf", "Random", "ReTestItems", "Reexport", "Statistics", "StochasticDiffEq", "Test", "Zygote"]
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ explore various ways to integrate the two methodologies:
- `TensorLayer` has been removed, use `Boltz.Layers.TensorProductLayer` instead.
- Basis functions in DiffEqFlux have been removed in favor of `Boltz.Basis` module.
- `SplineLayer` has been removed, use `Boltz.Layers.SplineLayer` instead.
- `NeuralHamiltonianDE` has been removed, use `NeuralODE` with `Layers.HamiltonianNN` instead.
- `HamiltonianNN` has been removed in favor of `Layers.HamiltonianNN`.

### v3

Expand Down
9 changes: 0 additions & 9 deletions docs/src/layers/HamiltonianNN.md

This file was deleted.

8 changes: 3 additions & 5 deletions src/DiffEqFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ using ConcreteStructs: @concrete
using DataInterpolations: DataInterpolations
using Distributions: Distributions, ContinuousMultivariateDistribution, Distribution, logpdf
using DistributionsAD: DistributionsAD
using ForwardDiff: ForwardDiff
using Functors: Functors, fmap
using LinearAlgebra: LinearAlgebra, Diagonal, det, tr, mul!
using Lux: Lux, Chain, Dense, StatefulLuxLayer, FromFluxAdaptor
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
Expand All @@ -23,15 +21,16 @@ using SciMLSensitivity: SciMLSensitivity, AdjointLSS, BacksolveAdjoint, EnzymeVJ
NILSS, QuadratureAdjoint, ReverseDiffAdjoint, ReverseDiffVJP,
SteadyStateAdjoint, TrackerAdjoint, TrackerVJP, ZygoteAdjoint,
ZygoteVJP
using Zygote: Zygote

const CRC = ChainRulesCore

@reexport using ADTypes, Lux, Boltz

# Select the `FST` type parameter for `StatefulLuxLayer{FST}` wrapping: every
# layer defaults to a fixed state type (`true`), while `Layers.HamiltonianNN`
# carries its own choice in its `FST` type parameter.
fixed_state_type(::Any) = true
fixed_state_type(::Layers.HamiltonianNN{FST}) where {FST} = FST

include("ffjord.jl")
include("neural_de.jl")
include("hnn.jl")

include("collocation.jl")
include("multiple_shooting.jl")
Expand All @@ -40,7 +39,6 @@ include("deprecated.jl")

export NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, AugmentedNDELayer,
NeuralODEMM
export NeuralHamiltonianDE, HamiltonianNN
export FFJORD, FFJORDDistribution
export DimMover

Expand Down
18 changes: 16 additions & 2 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Formerly TensorLayer
# Tensor Layer
Base.@deprecate TensorProductBasisFunction(f, n) Basis.GeneralBasisFunction{:none}(f, n, 1)

for B in (:Chebyshev, :Sin, :Cos, :Fourier, :Legendre, :Polynomial)
Expand All @@ -9,7 +9,7 @@ end
Base.@deprecate TensorLayer(model, out_dim::Int, init_p::F = randn) where {F <: Function} Boltz.Layers.TensorProductLayer(
model, out_dim; init_weight = init_p)

# Formerly SplineLayer
# Spline Layer
function SplineLayer(tspan, tstep, spline_basis; init_saved_points::F = nothing) where {F}
Base.depwarn(
"SplineLayer is deprecated and will be removed in the next major release. Refer to \
Expand All @@ -31,3 +31,17 @@ function SplineLayer(tspan, tstep, spline_basis; init_saved_points::F = nothing)
end

export SplineLayer

# Hamiltonian Neural Network
# Deprecation shim: the old `HamiltonianNN(model; ad)` constructor now forwards to
# `Boltz.Layers.HamiltonianNN{true}`, renaming the `ad` keyword to `autodiff`.
Base.@deprecate HamiltonianNN(model; ad = AutoZygote()) Layers.HamiltonianNN{true}(
    model; autodiff = ad)

"""
    NeuralHamiltonianDE(model, tspan, args...; ad = AutoForwardDiff(), kwargs...)

Deprecated constructor kept for backward compatibility. Emits a depwarn and
returns a `NeuralODE` wrapping `model` (converted to a `Layers.HamiltonianNN`
via the deprecated `HamiltonianNN` shim unless it already is one), forwarding
all remaining positional and keyword arguments to `NeuralODE`.
"""
function NeuralHamiltonianDE(model, tspan, args...; ad = AutoForwardDiff(), kwargs...)
    Base.depwarn(
        "NeuralHamiltonianDE is deprecated, use `NeuralODE` with `Layers.HamiltonianNN` instead.",
        :NeuralHamiltonianDE)
    hnn = model isa Layers.HamiltonianNN ? model : HamiltonianNN(model; ad)
    # BUG FIX: the captured varargs/kwargs must be splatted through. The previous
    # `NeuralODE(hnn, tspan, args, kwargs)` passed the raw tuple and the kwargs
    # container as two extra positional arguments, corrupting the solver call.
    return NeuralODE(hnn, tspan, args...; kwargs...)
end

export NeuralHamiltonianDE
14 changes: 2 additions & 12 deletions src/ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ function __forward_ffjord(n::FFJORD, x::AbstractArray{T, N}, ps, st) where {T, N
(; regularize, monte_carlo) = st
sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())

model = StatefulLuxLayer{true}(n.model, nothing, st.model)
model = StatefulLuxLayer{fixed_state_type(n.model)}(n.model, nothing, st.model)

ffjord(u, p, t) = __ffjord(model, u, p, n.ad, regularize, monte_carlo)

Expand Down Expand Up @@ -216,17 +216,7 @@ Arguments:
end

Base.length(d::FFJORDDistribution) = prod(d.model.input_dims)
Base.eltype(d::FFJORDDistribution) = __eltype(d.ps)

__eltype(x::AbstractArray) = eltype(x)
function __eltype(x)
T = Ref(Bool)
fmap(x) do x_
T[] = promote_type(T[], __eltype(x_))
return x_
end
return T[]
end
Base.eltype(d::FFJORDDistribution) = Lux.recursive_eltype(d.ps)

function Distributions._logpdf(d::FFJORDDistribution, x::AbstractVector)
return first(first(__forward_ffjord(d.model, reshape(x, :, 1), d.ps, d.st)))
Expand Down
92 changes: 0 additions & 92 deletions src/hnn.jl

This file was deleted.

18 changes: 10 additions & 8 deletions src/neural_de.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function NeuralODE(model, tspan, args...; kwargs...)
end

function (n::NeuralODE)(x, p, st)
model = StatefulLuxLayer{true}(n.model, nothing, st)
model = StatefulLuxLayer{fixed_state_type(n.model)}(n.model, nothing, st)

dudt(u, p, t) = model(u, p)
ff = ODEFunction{false}(dudt; tgrad = basic_tgrad)
Expand Down Expand Up @@ -93,8 +93,9 @@ function NeuralDSDE(drift, diffusion, tspan, args...; kwargs...)
end

function (n::NeuralDSDE)(x, p, st)
drift = StatefulLuxLayer{true}(n.drift, nothing, st.drift)
diffusion = StatefulLuxLayer{true}(n.diffusion, nothing, st.diffusion)
drift = StatefulLuxLayer{fixed_state_type(n.drift)}(n.drift, nothing, st.drift)
diffusion = StatefulLuxLayer{fixed_state_type(n.diffusion)}(
n.diffusion, nothing, st.diffusion)

dudt(u, p, t) = drift(u, p.drift)
g(u, p, t) = diffusion(u, p.diffusion)
Expand Down Expand Up @@ -143,8 +144,9 @@ function NeuralSDE(drift, diffusion, tspan, nbrown, args...; kwargs...)
end

function (n::NeuralSDE)(x, p, st)
drift = StatefulLuxLayer{true}(n.drift, p.drift, st.drift)
diffusion = StatefulLuxLayer{true}(n.diffusion, p.diffusion, st.diffusion)
drift = StatefulLuxLayer{fixed_state_type(n.drift)}(n.drift, p.drift, st.drift)
diffusion = StatefulLuxLayer{fixed_state_type(n.diffusion)}(
n.diffusion, p.diffusion, st.diffusion)

dudt(u, p, t) = drift(u, p.drift)
g(u, p, t) = diffusion(u, p.diffusion)
Expand Down Expand Up @@ -196,7 +198,7 @@ function NeuralCDDE(model, tspan, hist, lags, args...; kwargs...)
end

function (n::NeuralCDDE)(x, ps, st)
model = StatefulLuxLayer{true}(n.model, nothing, st)
model = StatefulLuxLayer{fixed_state_type(n.model)}(n.model, nothing, st)

function dudt(u, h, p, t)
xs = mapfoldl(lag -> h(p, t - lag), vcat, n.lags)
Expand Down Expand Up @@ -249,7 +251,7 @@ end

function (n::NeuralDAE)(u_du::Tuple, p, st)
u0, du0 = u_du
model = StatefulLuxLayer{true}(n.model, nothing, st)
model = StatefulLuxLayer{fixed_state_type(n.model)}(n.model, nothing, st)

function f(du, u, p, t)
nn_out = model(vcat(u, du), p)
Expand Down Expand Up @@ -322,7 +324,7 @@ function NeuralODEMM(model, constraints_model, tspan, mass_matrix, args...; kwar
end

function (n::NeuralODEMM)(x, ps, st)
model = StatefulLuxLayer{true}(n.model, nothing, st)
model = StatefulLuxLayer{fixed_state_type(n.model)}(n.model, nothing, st)

function f(u, p, t)
nn_out = model(u, p)
Expand Down

0 comments on commit 7f101e3

Please sign in to comment.