Skip to content

Commit

Permalink
feat: update to Functors v0.5 (#237)
Browse files Browse the repository at this point in the history
* feat: update to Functors v0.5

* fix: incorrect functors defn
  • Loading branch information
avik-pal authored Nov 15, 2024
1 parent c18d8ca commit 6cbe81b
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 73 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Downgrade.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
version: ['1']
version: ['1.10']
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqCallbacks"
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
authors = ["Chris Rackauckas <[email protected]>"]
version = "4.1.0"
version = "4.2.0"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Expand All @@ -25,10 +25,10 @@ DataStructures = "0.18.13"
DiffEqBase = "6.155.3"
DifferentiationInterface = "0.6.1"
ForwardDiff = "0.10.36"
Functors = "0.4"
Functors = "0.5"
LinearAlgebra = "1.10"
Markdown = "1.10"
NonlinearSolve = "3.14"
NonlinearSolve = "3.14, 4"
ODEProblemLibrary = "0.1.8"
OrdinaryDiffEq = "6.88"
QuadGK = "2.9"
Expand Down
1 change: 1 addition & 0 deletions docs/src/projection.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ prob = ODEProblem(f, u0, (0.0, 100.0))
```

!!! note

NonlinearSolve.jl must be imported in order to use `ManifoldProjection`.

However, this problem is supposed to conserve energy, and thus we define our manifold
Expand Down
108 changes: 41 additions & 67 deletions src/functor_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,125 +6,99 @@
`y[:] .= vec(x)` for generic `x` and `y`. This is used to handle non-array parameters!
"""
# recursive_copyto!(y, x): in-place `y .= x` over arbitrarily nested parameter
# structures (arrays, tuples, named tuples, structs).
# NOTE(review): this span is a rendered diff — it interleaves the old
# Tuple/NamedTuple/`::T` structural-recursion methods (removed) with the new
# Functors.fmap-based implementation (added). The `function ... NamedTuple`
# header below is missing its closing `end` in this view; do not read the span
# as one coherent program.
recursive_copyto!(y::AbstractArray, x::AbstractArray) = copyto!(y, x)
# Old (pre-Functors v0.5) hand-rolled structural recursion:
recursive_copyto!(y::Tuple, x::Tuple) = map(recursive_copyto!, y, x)
function recursive_copyto!(y::NamedTuple{F}, x::NamedTuple{F}) where {F}
map(recursive_copyto!, values(y), values(x))
# New: delegate the structural walk to Functors.fmap with a leaf-level helper.
recursive_copyto!(y, x) = fmap(internal_copyto!, y, x)

# Leaf handler: copy only when a `copyto!` method exists for this (y, x) pair;
# leaves with no such method are skipped (returns `nothing`).
function internal_copyto!(y, x)
hasmethod(copyto!, Tuple{typeof(y), typeof(x)}) ? copyto!(y, x) : nothing
end
recursive_copyto!(y::T, x::T) where {T} = fmap(recursive_copyto!, y, x)
recursive_copyto!(y, ::Nothing) = y
recursive_copyto!(::Nothing, ::Nothing) = nothing

"""
neg!(x)
`x .*= -1` for generic `x`. This is used to handle non-array parameters!
"""
recursive_neg!(x::AbstractArray) = (x .*= -1)
recursive_neg!(x::Tuple) = map(recursive_neg!, x)
recursive_neg!(x::NamedTuple{F}) where {F} = NamedTuple{F}(map(recursive_neg!, values(x)))
recursive_neg!(x) = fmap(recursive_neg!, x)
recursive_neg!(::Nothing) = nothing
recursive_neg!(x) = fmap(internal_neg!, x)

internal_neg!(x::AbstractArray) = x .*= -1
internal_neg!(x) = nothing

"""
zero!(x)
`x .= 0` for generic `x`. This is used to handle non-array parameters!
"""
recursive_zero!(x::AbstractArray) = (x .= 0)
recursive_zero!(x::Tuple) = map(recursive_zero!, x)
recursive_zero!(x::NamedTuple{F}) where {F} = NamedTuple{F}(map(recursive_zero!, values(x)))
recursive_zero!(x) = fmap(recursive_zero!, x)
recursive_zero!(::Nothing) = nothing
recursive_zero!(x) = fmap(internal_zero!, x)

internal_zero!(x::AbstractArray) = fill!(x, false)
internal_zero!(x) = nothing

"""
recursive_sub!(y, x)
`y .-= x` for generic `x` and `y`. This is used to handle non-array parameters!
"""
recursive_sub!(y::AbstractArray, x::AbstractArray) = axpy!(-1, x, y)
recursive_sub!(y::Tuple, x::Tuple) = map(recursive_sub!, y, x)
function recursive_sub!(y::NamedTuple{F}, x::NamedTuple{F}) where {F}
NamedTuple{F}(map(recursive_sub!, values(y), values(x)))
end
recursive_sub!(y::T, x::T) where {T} = fmap(recursive_sub!, y, x)
recursive_sub!(y, ::Nothing) = y
recursive_sub!(::Nothing, ::Nothing) = nothing
recursive_sub!(y, x) = fmap(internal_sub!, y, x)

internal_sub!(y::AbstractArray, x::AbstractArray) = axpy!(-1, x, y)
internal_sub!(y, x) = nothing

"""
recursive_add!(y, x)
`y .+= x` for generic `x` and `y`. This is used to handle non-array parameters!
"""
recursive_add!(y::AbstractArray, x::AbstractArray) = y .+= x
recursive_add!(y::Tuple, x::Tuple) = recursive_add!.(y, x)
function recursive_add!(y::NamedTuple{F}, x::NamedTuple{F}) where {F}
NamedTuple{F}(recursive_add!(values(y), values(x)))
end
recursive_add!(y::T, x::T) where {T} = fmap(recursive_add!, y, x)
recursive_add!(y, ::Nothing) = y
recursive_add!(::Nothing, ::Nothing) = nothing
recursive_add!(y, x) = fmap(internal_add!, y, x)

internal_add!(y::AbstractArray, x::AbstractArray) = y .+= x
internal_add!(y, x) = nothing

"""
allocate_vjp(λ, x)
allocate_vjp(x)
`similar(λ, size(x))` for generic `x`. This is used to handle non-array parameters!
"""
allocate_vjp::AbstractArray, x::AbstractArray) = similar(λ, size(x))
allocate_vjp::AbstractArray, x::Tuple) = allocate_vjp.((λ,), x)
function allocate_vjp::AbstractArray, x::NamedTuple{F}) where {F}
NamedTuple{F}(allocate_vjp.((λ,), values(x)))
end
allocate_vjp::AbstractArray, x) = fmap(Base.Fix1(allocate_vjp, λ), x)
allocate_vjp::AbstractArray, x) = fmap(Base.Fix1(allocate_vjp_internal, λ), x)
allocate_vjp(x) = fmap(similar, x)

allocate_vjp(x::AbstractArray) = similar(x)
allocate_vjp(x::Tuple) = allocate_vjp.(x)
allocate_vjp(x::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_vjp.(values(x)))
allocate_vjp(x) = fmap(allocate_vjp, x)
allocate_vjp_internal::AbstractArray, x) = similar(λ, size(x))

"""
allocate_zeros(x)
`zero.(x)` for generic `x`. This is used to handle non-array parameters!
"""
allocate_zeros(x::AbstractArray) = zero.(x)
allocate_zeros(x::Tuple) = allocate_zeros.(x)
allocate_zeros(x::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_zeros.(values(x)))
allocate_zeros(x) = fmap(allocate_zeros, x)
allocate_zeros(x) = fmap(internal_allocate_zeros, x)

internal_allocate_zeros(x) = hasmethod(zero, Tuple{typeof(x)}) ? zero(x) : nothing

"""
recursive_copy(y)
`copy(y)` for generic `y`. This is used to handle non-array parameters!
"""
recursive_copy(y::AbstractArray) = copy(y)
recursive_copy(y::Tuple) = recursive_copy.(y)
recursive_copy(y::NamedTuple{F}) where {F} = NamedTuple{F}(recursive_copy.(values(y)))
recursive_copy(y) = fmap(recursive_copy, y)
recursive_copy(y) = fmap(internal_copy, y)

internal_copy(x) = hasmethod(copy, Tuple{typeof(x)}) ? copy(x) : nothing

"""
recursive_adjoint(y)
`adjoint(y)` for generic `y`. This is used to handle non-array parameters!
"""
recursive_adjoint(y::AbstractArray) = adjoint(y)
recursive_adjoint(y::Tuple) = recursive_adjoint.(y)
recursive_adjoint(y::NamedTuple{F}) where {F} = NamedTuple{F}(recursive_adjoint.(values(y)))
recursive_adjoint(y) = fmap(recursive_adjoint, y)
recursive_adjoint(y) = fmap(internal_adjoint, y)

internal_adjoint(x) = hasmethod(adjoint, Tuple{typeof(x)}) ? adjoint(x) : nothing

# scalar_mul!
# recursive_scalar_mul!(x, α): `x .*= α` over nested parameter structures.
recursive_scalar_mul!(x::AbstractArray, α) = x .*= α
# NOTE(review): diff view — old structural methods (removed) followed by the
# new fmap-based method (added).
recursive_scalar_mul!(x::Tuple, α) = recursive_scalar_mul!.(x, α)
function recursive_scalar_mul!(x::NamedTuple{F}, α) where {F}
return NamedTuple{F}(recursive_scalar_mul!(values(x), α))
end
# The old fallback used Base.Fix1(recursive_scalar_mul!, α), which binds α as
# the FIRST positional argument even though the signature takes x first — the
# commit ("fix: incorrect functors defn") replaces it with Base.Fix2 so the
# leaf receives (leaf, α) in the correct order.
recursive_scalar_mul!(x, α) = fmap(Base.Fix1(recursive_scalar_mul!, α), x)
recursive_scalar_mul!(x, α) = fmap(Base.Fix2(internal_scalar_mul!, α), x)

# Leaf handlers: Numbers are immutable, so scaling returns a new value for
# fmap to splice back in; arrays are scaled in place; other leaves skipped.
internal_scalar_mul!(x::Number, α) = x * α
internal_scalar_mul!(x::AbstractArray, α) = x .*= α
internal_scalar_mul!(x, α) = nothing

# axpy!
# recursive_axpy!(α, x, y): `y .+= α .* x` over nested parameter structures.
recursive_axpy!(α, x::AbstractArray, y::AbstractArray) = axpy!(α, x, y)
# NOTE(review): diff view — old structural methods (removed) followed by the
# new fmap-based method (added).
recursive_axpy!(α, x::Tuple, y::Tuple) = recursive_axpy!.(α, x, y)
function recursive_axpy!(α, x::NamedTuple{F}, y::NamedTuple{F}) where {F}
return NamedTuple{F}(recursive_axpy!(α, values(x), values(y)))
end
# Old fallback used Fix1 (would misplace α); new version closes over α with an
# explicit two-argument lambda so paired leaves arrive as (xᵢ, yᵢ).
recursive_axpy!(α, x, y) = fmap(Base.Fix1(recursive_axpy!, α), x, y)
recursive_axpy!(α, x, y) = fmap((xᵢ, yᵢ) -> internal_axpy!(α, xᵢ, yᵢ), x, y)

# Leaf handler: BLAS axpy! on array pairs; other leaf pairs skipped.
internal_axpy!(α, x::AbstractArray, y::AbstractArray) = axpy!(α, x, y)
internal_axpy!(α, x, y) = nothing
2 changes: 1 addition & 1 deletion src/integrating_sum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function (affect!::SavingIntegrandSumAffect)(integrator)
if DiffEqBase.isinplace(integrator.sol.prob)
curu = first(get_tmp_cache(integrator))
integrator(curu, t_temp)
if affect!.integrand_cache == nothing
if affect!.integrand_cache === nothing
recursive_axpy!(gauss_weights[n][i],
affect!.integrand_func(curu, t_temp, integrator), accumulation_cache)
else
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
const GROUP = get(ENV, "GROUP", "All")

# write your own tests here
@time begin
@time @testset "DiffEqCallbacks" begin
if GROUP == "QA"
@time @testset "Quality Assurance" begin
include("qa.jl")
Expand Down

2 comments on commit 6cbe81b

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/119542

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v4.2.0 -m "<description of version>" 6cbe81b57dc112631f96b1885e8c2fd0050ccff6
git push origin v4.2.0

Please sign in to comment.