Merge pull request #86 from JuliaGNI/update-dependencies
Update dependencies
michakraus authored Nov 7, 2023
2 parents 6c41e71 + d04b09a commit 9445c71
Showing 8 changed files with 17 additions and 25 deletions.
18 changes: 9 additions & 9 deletions Project.toml
@@ -36,24 +36,24 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractNeuralNetworks = "0.1"
-BandedMatrices = "0.17"
-CUDA = "4"
+BandedMatrices = "0.17, 1"
+CUDA = "4, 5"
ChainRules = "1"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
Distances = "0.10"
-Documenter = "0.27"
+Documenter = "0.27, 1"
ForwardDiff = "0.10"
-GPUArrays = "8"
-GeometricBase = "0.8"
-GeometricEquations = "0.12"
-GeometricIntegrators = "0.12"
-HDF5 = "0.16"
+GPUArrays = "8, 9"
+GeometricBase = "0.9"
+GeometricEquations = "0.14"
+GeometricIntegrators = "0.13"
+HDF5 = "0.16, 0.17"
KernelAbstractions = "0.9"
Lux = "0.4, 0.5"
NLsolve = "4"
NNlib = "0.8, 0.9"
-Optimisers = "0.2"
+Optimisers = "0.2, 0.3"
ProgressMeter = "1"
SafeTestsets = "0.1"
StatsBase = "0.33, 0.34"
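
A note on the notation: each [compat] entry is a comma-separated union of version specifiers, so a bound such as BandedMatrices = "0.17, 1" accepts both the 0.17 series and the 1.x series; the widened entries above admit the new major releases without dropping older environments. A minimal sketch of checking what such a bound allows, using Pkg's semver_spec helper (an internal function whose location varies across Pkg versions, so treat the exact call path as an assumption):

using Pkg

# Parse the widened BandedMatrices bound from this commit; on some Julia
# versions the helper lives in Pkg.Versions rather than Pkg.Types.
spec = Pkg.Types.semver_spec("0.17, 1")

println(v"0.17.3" in spec)   # true  — the old 0.17 series is still allowed
println(v"1.1.0" in spec)    # true  — the new 1.x series is allowed as well
println(v"0.18.0" in spec)   # false — not covered by either specifier
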
2 changes: 1 addition & 1 deletion src/GeometricMachineLearning.jl
@@ -33,7 +33,7 @@ module GeometricMachineLearning
import AbstractNeuralNetworks: params, architecture, model, dim
# export params, architecture, model
export dim
-import GeometricIntegrators.Integrators: method
+import GeometricIntegrators.Integrators: method, GeometricIntegrator
import NNlib: σ, sigmoid, softmax
#import LogExpFunctions: softmax

2 changes: 1 addition & 1 deletion src/data_loader/data_loader.jl
@@ -18,7 +18,7 @@ For drawing the batch, the sampling is done over n_params and n_time_steps (here
If for the output we have a tensor whose second axis has length 1, we still store it as a tensor and not a matrix. This is because it is not necessarily of length 1.
-TODO: Implement DataLoader that works well with GeometricEnsembles etc.
+TODO: Implement DataLoader that works well with EnsembleProblems etc.
"""
struct DataLoader{T, AT<:AbstractArray{T}, OT<:Union{AbstractArray, Nothing}}
input::AT
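
The docstring above fixes a shape convention: batches are drawn over the n_time_steps and n_params axes, and an output whose second axis happens to have length 1 is still kept as a rank-3 tensor rather than collapsed to a matrix. A small sketch of arrays following that convention; the axis order and the constructor call are illustrative assumptions, not part of this commit:

# Assumed layout: (system dimension, n_time_steps, n_params).
input  = rand(Float32, 10, 100, 50)

# Output with a singleton second axis: kept as a 3-tensor, since in general
# that axis need not have length 1.
output = rand(Float32, 1, 1, 50)

# dl = DataLoader(input, output)   # hypothetical call; the signature is not shown here
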
2 changes: 1 addition & 1 deletion src/integrator/problem_hnn.jl
@@ -20,7 +20,7 @@ function HNNProblem(nn::NeuralNetwork{<:HamiltonianNeuralNetwork}, tspan, tstep,
HODEProblem(v, f, hamiltonian, tspan, tstep, ics; kwargs...)
end

-function HNNProblem(nn::NeuralNetwork{<:HamiltonianNeuralNetwork}, tspan, tstep, q₀::State, p₀::State; kwargs...)
+function HNNProblem(nn::NeuralNetwork{<:HamiltonianNeuralNetwork}, tspan, tstep, q₀::StateVariable, p₀::StateVariable; kwargs...)
ics = (q = q₀, p = p₀)
HNNProblem(nn, tspan, tstep, ics; kwargs...)
end
4 changes: 2 additions & 2 deletions src/integrator/problem_lnn.jl
@@ -31,11 +31,11 @@ function LNNProblem(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, tspan::Tuple,
LNNProblem(nn, GeometricEquations._lode_default_g, tspan, tstep, ics; kwargs...)
end

-function LNNProblem(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, g, tspan::Tuple, tstep::Real, q₀::State, p₀::State, λ₀::State = zero(q₀); kwargs...)
+function LNNProblem(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, g, tspan::Tuple, tstep::Real, q₀::StateVariable, p₀::StateVariable, λ₀::StateVariable = zero(q₀); kwargs...)
ics = (q = q₀, p = p₀, λ = λ₀)
LNNProblem(nn, g, tspan, tstep, ics; kwargs...)
end

-function LNNProblem(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, tspan::Tuple, tstep::Real, q₀::State, p₀::State, λ₀::State = zero(q₀); kwargs...)
+function LNNProblem(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, tspan::Tuple, tstep::Real, q₀::StateVariable, p₀::StateVariable, λ₀::StateVariable = zero(q₀); kwargs...)
LNNProblem(nn, GeometricEquations._lode_default_g, tspan, tstep, q₀, p₀, λ₀; kwargs...)
end
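
Both LNNProblem methods above, like the HNNProblem method in the previous file, now annotate their initial conditions as StateVariable instead of State. A hedged usage sketch of the new HNNProblem signature, assuming StateVariable comes from the GeometricEquations/GeometricBase stack and wraps a plain vector, and that nn is an already trained HamiltonianNeuralNetwork:

# Assumed: StateVariable(::AbstractVector) constructs the wrapper expected above.
q₀ = StateVariable([1.0, 0.0])
p₀ = StateVariable([0.0, 0.5])

# prob = HNNProblem(nn, (0.0, 10.0), 0.1, q₀, p₀)
# sol  = integrate(prob, ImplicitMidpoint())   # hypothetical solver choice
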
7 changes: 2 additions & 5 deletions src/integrator/sympnet_integrator.jl
@@ -10,15 +10,12 @@ function method(nns::NeuralNetSolution{<: NeuralNetwork{<:SympNet}})
end


-const IntegratorSympNet{DT,TT} = Integrator{<:Union{HODEProblem{DT,TT}}, <:SympNetMethod}
-
-
-function GeometricIntegrators.integrate(nns::NeuralNetSolution; kwargs...)
+function GeometricIntegrators.Integrators.integrate(nns::NeuralNetSolution; kwargs...)
integrate(problem(nns), method(nns); kwargs...)
end


-function GeometricIntegrators.integrate_step!(int::IntegratorSympNet)
+function GeometricIntegrators.Integrators.integrate_step!(int::GeometricIntegrator{<:SympNetMethod, <:AbstractProblemPODE})

# compute how many times to compose nn ()
@assert GeometricIntegrators.Integrators.method(int).Δt % GeometricIntegrators.timestep(int) == 0
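
The assertion at the end of this hunk requires the SympNet method's Δt to be an exact multiple of the integrator's time step before the network is composed repeatedly. A toy illustration of that divisibility check, with assumed values chosen to be exact in binary so that % behaves predictably:

Δt    = 0.5     # step stored in the SympNet method (assumed value)
tstep = 0.125   # integrator time step (assumed value)

# The check enforced above; with decimal steps such as 0.1 the floating-point
# remainder may not be exactly zero even when the steps divide evenly on paper.
@assert Δt % tstep == 0
ratio = Int(Δt ÷ tstep)   # 4 — the integer ratio the assertion guarantees
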
2 changes: 1 addition & 1 deletion src/nnsolution/neural_net_solution.jl
@@ -29,7 +29,7 @@ update_history(nns::NeuralNetSolution, sg::SingleHistory) = _add(nns.history, sg

@inline nn(nns::NeuralNetSolution) = nns.nn
@inline problem(nns::NeuralNetSolution) = nns.problem
-@inline problem(nns::NeuralNetSolution{T,<:GeometricEnsemble} where T) = GeometricEquations.problem(nns.problem,1)
+@inline problem(nns::NeuralNetSolution{T,<:EnsembleProblem} where T) = GeometricEquations.problem(nns.problem,1)
@inline GeometricBase.tstep(nns::NeuralNetSolution) = nns.tstep
@inline loss(nns::NeuralNetSolution) = nns.loss

5 changes: 0 additions & 5 deletions test/integrator/test_integrator.jl
@@ -63,8 +63,3 @@ training_parameters = TrainingParameters(nruns, method, mopt; batch_size = batch
neural_net_solution = train!(neuralnet, training_data, training_parameters)

@testnoerror sol = integrate(neural_net_solution)
-
-
-
-
-