Skip to content

Commit

Permalink
feat: update to Functors v0.5
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 15, 2024
1 parent c18d8ca commit 4048a5f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 59 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqCallbacks"
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
authors = ["Chris Rackauckas <[email protected]>"]
version = "4.1.0"
version = "4.2.0"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Expand All @@ -25,7 +25,7 @@ DataStructures = "0.18.13"
DiffEqBase = "6.155.3"
DifferentiationInterface = "0.6.1"
ForwardDiff = "0.10.36"
Functors = "0.4"
Functors = "0.5"
LinearAlgebra = "1.10"
Markdown = "1.10"
NonlinearSolve = "3.14"
Expand Down
100 changes: 43 additions & 57 deletions src/functor_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,64 +6,55 @@
`y[:] .= vec(x)` for generic `x` and `y`. This is used to handle non-array parameters!
"""
recursive_copyto!(y::AbstractArray, x::AbstractArray) = copyto!(y, x)
recursive_copyto!(y::Tuple, x::Tuple) = map(recursive_copyto!, y, x)
function recursive_copyto!(y::NamedTuple{F}, x::NamedTuple{F}) where {F}
map(recursive_copyto!, values(y), values(x))
recursive_copyto!(y, x) = fmap(internal_copyto!, y, x)

function internal_copyto!(y, x)
hasmethod(copyto!, Tuple{typeof(y), typeof(x)}) ? copyto!(y, x) : nothing
end
recursive_copyto!(y::T, x::T) where {T} = fmap(recursive_copyto!, y, x)
recursive_copyto!(y, ::Nothing) = y
recursive_copyto!(::Nothing, ::Nothing) = nothing

"""
neg!(x)
`x .*= -1` for generic `x`. This is used to handle non-array parameters!
"""
recursive_neg!(x::AbstractArray) = (x .*= -1)
recursive_neg!(x::Tuple) = map(recursive_neg!, x)
recursive_neg!(x::NamedTuple{F}) where {F} = NamedTuple{F}(map(recursive_neg!, values(x)))
recursive_neg!(x) = fmap(recursive_neg!, x)
recursive_neg!(::Nothing) = nothing
recursive_neg!(x) = fmap(internal_neg!, x)

internal_neg!(x::Number) = -x
internal_neg!(x::AbstractArray) = x .*= -1
internal_neg!(x) = nothing

"""
zero!(x)
`x .= 0` for generic `x`. This is used to handle non-array parameters!
"""
recursive_zero!(x::AbstractArray) = (x .= 0)
recursive_zero!(x::Tuple) = map(recursive_zero!, x)
recursive_zero!(x::NamedTuple{F}) where {F} = NamedTuple{F}(map(recursive_zero!, values(x)))
recursive_zero!(x) = fmap(recursive_zero!, x)
recursive_zero!(::Nothing) = nothing
recursive_zero!(x) = fmap(internal_zero!, x)

internal_zero!(x::Number) = zero(x)
internal_zero!(x::AbstractArray) = fill!(x, zero(eltype(x)))
internal_zero!(x) = nothing

"""
recursive_sub!(y, x)
`y .-= x` for generic `x` and `y`. This is used to handle non-array parameters!
"""
recursive_sub!(y::AbstractArray, x::AbstractArray) = axpy!(-1, x, y)
recursive_sub!(y::Tuple, x::Tuple) = map(recursive_sub!, y, x)
function recursive_sub!(y::NamedTuple{F}, x::NamedTuple{F}) where {F}
NamedTuple{F}(map(recursive_sub!, values(y), values(x)))
end
recursive_sub!(y::T, x::T) where {T} = fmap(recursive_sub!, y, x)
recursive_sub!(y, ::Nothing) = y
recursive_sub!(::Nothing, ::Nothing) = nothing
recursive_sub!(y, x) = fmap(internal_sub!, y, x)

internal_sub!(x::Number, y::Number) = x - y
internal_sub!(x::AbstractArray, y::AbstractArray) = axpy!(-1, y, x)
internal_sub!(x, y) = nothing

"""
recursive_add!(y, x)
`y .+= x` for generic `x` and `y`. This is used to handle non-array parameters!
"""
recursive_add!(y::AbstractArray, x::AbstractArray) = y .+= x
recursive_add!(y::Tuple, x::Tuple) = recursive_add!.(y, x)
function recursive_add!(y::NamedTuple{F}, x::NamedTuple{F}) where {F}
NamedTuple{F}(recursive_add!(values(y), values(x)))
end
recursive_add!(y::T, x::T) where {T} = fmap(recursive_add!, y, x)
recursive_add!(y, ::Nothing) = y
recursive_add!(::Nothing, ::Nothing) = nothing
recursive_add!(y, x) = fmap(internal_add!, y, x)

internal_add!(x::Number, y::Number) = x + y
internal_add!(x::AbstractArray, y::AbstractArray) = y .+= x
internal_add!(x, y) = nothing

"""
allocate_vjp(λ, x)
Expand All @@ -88,43 +79,38 @@ allocate_vjp(x) = fmap(allocate_vjp, x)
`zero.(x)` for generic `x`. This is used to handle non-array parameters!
"""
allocate_zeros(x::AbstractArray) = zero.(x)
allocate_zeros(x::Tuple) = allocate_zeros.(x)
allocate_zeros(x::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_zeros.(values(x)))
allocate_zeros(x) = fmap(allocate_zeros, x)
allocate_zeros(x) = fmap(internal_allocate_zeros, x)

internal_allocate_zeros(x) = hasmethod(zero, Tuple{typeof(x)}) ? zero(x) : nothing

"""
recursive_copy(y)
`copy(y)` for generic `y`. This is used to handle non-array parameters!
"""
recursive_copy(y::AbstractArray) = copy(y)
recursive_copy(y::Tuple) = recursive_copy.(y)
recursive_copy(y::NamedTuple{F}) where {F} = NamedTuple{F}(recursive_copy.(values(y)))
recursive_copy(y) = fmap(recursive_copy, y)
recursive_copy(y) = fmap(internal_copy, y)

internal_copy(x) = hasmethod(copy, Tuple{typeof(x)}) ? copy(x) : nothing

"""
recursive_adjoint(y)
`adjoint(y)` for generic `y`. This is used to handle non-array parameters!
"""
recursive_adjoint(y::AbstractArray) = adjoint(y)
recursive_adjoint(y::Tuple) = recursive_adjoint.(y)
recursive_adjoint(y::NamedTuple{F}) where {F} = NamedTuple{F}(recursive_adjoint.(values(y)))
recursive_adjoint(y) = fmap(recursive_adjoint, y)
recursive_adjoint(y) = fmap(internal_adjoint, y)

internal_adjoint(x) = hasmethod(adjoint, Tuple{typeof(x)}) ? adjoint(x) : nothing

# scalar_mul!
recursive_scalar_mul!(x::AbstractArray, α) = x .*= α
recursive_scalar_mul!(x::Tuple, α) = recursive_scalar_mul!.(x, α)
function recursive_scalar_mul!(x::NamedTuple{F}, α) where {F}
return NamedTuple{F}(recursive_scalar_mul!(values(x), α))
end
recursive_scalar_mul!(x, α) = fmap(Base.Fix1(recursive_scalar_mul!, α), x)
recursive_scalar_mul!(x, α) = fmap(Base.Fix2(internal_scalar_mul!, α), x)

internal_scalar_mul!(x::Number, α) = x * α
internal_scalar_mul!(x::AbstractArray, α) = x .*= α
internal_scalar_mul!(x, α) = nothing

# axpy!
recursive_axpy!(α, x::AbstractArray, y::AbstractArray) = axpy!(α, x, y)
recursive_axpy!(α, x::Tuple, y::Tuple) = recursive_axpy!.(α, x, y)
function recursive_axpy!(α, x::NamedTuple{F}, y::NamedTuple{F}) where {F}
return NamedTuple{F}(recursive_axpy!(α, values(x), values(y)))
end
recursive_axpy!(α, x, y) = fmap(Base.Fix1(recursive_axpy!, α), x, y)
recursive_axpy!(α, x, y) = fmap((xᵢ, yᵢ) -> internal_axpy!(α, xᵢ, yᵢ), x, y)

internal_axpy!(α, x::Number, y::Number) = y + α * x
internal_axpy!(α, x::AbstractArray, y::AbstractArray) = axpy!(α, x, y)
internal_axpy!(α, x, y) = nothing

0 comments on commit 4048a5f

Please sign in to comment.