Skip to content

Commit

Permalink
Add derivatives API
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed May 17, 2024
1 parent 39d0c76 commit 3770f8b
Showing 1 changed file with 44 additions and 55 deletions.
99 changes: 44 additions & 55 deletions src/derivative.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,11 @@

export derivative, derivative!
export derivative, derivative!, derivatives, make_seed

"""
derivative(f, x, order::Int64)
derivative(f, x, l, order::Int64)
Wrapper functions for converting order from a number to a type. Actual APIs are detailed below:
derivative(f, x::T, ::Val{N})
Computes `order`-th derivative of `f` w.r.t. scalar `x`.
derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N})
derivative(f, x, l, ::Val{N})
derivative(f!, y, x, l, ::Val{N})
Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
derivative(f, x::AbstractMatrix{T}, ::Val{N})
derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N})
Batch mode derivative / directional derivative calculations, where each column of `x` represents a scalar or a vector. `f` is expected to accept matrices as input.
- For a M-by-N matrix, calculate the directional derivative for each column.
- For a 1-by-N matrix (row vector), calculate the derivative for each scalar.
"""
function derivative end

Expand All @@ -32,54 +17,58 @@ In-place derivative calculation APIs. `result` is expected to be pre-allocated a
"""
function derivative! end

"""
derivatives(f, x, l, ::Val{N})
derivatives(f!, y, x, l, ::Val{N})
Computes all derivatives of `f` at `x` up to order `N - 1`.
"""
function derivatives end

# Convenience wrapper for adding unit seed to the input

@inline derivative(f, x, order::Int64) = derivative(f, x, one(eltype(x)), order)

# Convenience wrappers for converting orders to value types
# and forward work to core APIs

@inline derivative(f, x, order::Int64) = derivative(f, x, one(eltype(x)), order)
@inline derivative(f, x, l, order::Int64) = derivative(f, x, l, Val{order + 1}())
@inline derivative(f!, y, x, l, order::Int64) = derivative(f!, y, x, l, Val{order + 1}())
@inline derivative!(result, f, x, l, order::Int64) = derivative!(
result, f, x, l, Val{order + 1}())
@inline derivative!(result, f!, y, x, l, order::Int64) = derivative!(
result, f!, y, x, l, Val{order + 1}())

# Core APIs

# Added to help Zygote infer types
@inline function make_taylor(x::T, l::S, ::Val{N}) where {T <: TN, S <: TN, N}
@inline function make_seed(x::T, l::S, ::Val{N}) where {T <: TN, S <: TN, N}
TaylorScalar{T, N}(x, convert(T, l))
end

@inline function make_taylor(x::AbstractArray{T}, l, vN::Val{N}) where {T <: TN, N}
broadcast(make_taylor, x, l, vN)
end

# Out-of-place function, out-of-place derivative
@inline function derivative(f, x, l, vN::Val{N}) where {N}
t = make_taylor(x, l, vN)
return extract_derivative(f(t), N)
end

# Below three advanced APIs do not have convenience wrappers

# In-place function, out-of-place derivative
@inline function derivative(f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N}
s = similar(y, TaylorScalar{T, N})
t = make_taylor(x, l, vN)
f!(s, t)
map!(primal, y, s)
return extract_derivative(s, N)
end

# Out-of-place function, in-place derivative
@inline function derivative!(result, f, x, l, vN::Val{N}) where {N}
t = make_taylor(x, l, vN)
s = f(t)
extract_derivative!(result, s, N)
return result
@inline function make_seed(x::AbstractArray{T}, l, vN::Val{N}) where {T <: TN, N}
broadcast(make_seed, x, l, vN)
end

# In-place function, in-place derivative
@inline function derivative!(result, f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N}
s = similar(y, TaylorScalar{T, N})
t = make_taylor(x, l, vN)
f!(s, t)
map!(primal, y, s)
extract_derivative!(result, s, N)
return result
# `derivative` API: computes the `N - 1`-th derivative of `f` at `x`
@inline derivative(f, x, l, vN::Val{N}) where {N} = extract_derivative(
derivatives(f, x, l, vN), N)
@inline derivative(f!, y, x, l, vN::Val{N}) where {N} = extract_derivative(
derivatives(f!, y, x, l, vN), N)
@inline derivative!(result, f, x, l, vN::Val{N}) where {N} = extract_derivative!(
result, derivatives(f, x, l, vN), N)
@inline derivative!(result, f!, y, x, l, vN::Val{N}) where {N} = extract_derivative!(
result, derivatives(f!, y, x, l, vN), N)

# `derivatives` API: computes all derivatives of `f` at `x` up to order `N - 1`

# Out-of-place function
@inline derivatives(f, x, l, vN::Val{N}) where {N} = f(make_seed(x, l, vN))

# In-place function
@inline function derivatives(f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N}
buffer = similar(y, TaylorScalar{T, N})
f!(buffer, make_seed(x, l, vN))
map!(primal, y, buffer)
return buffer
end

0 comments on commit 3770f8b

Please sign in to comment.